diff --git a/.dockerignore b/.dockerignore index ee2e8af78..45174e861 100644 --- a/.dockerignore +++ b/.dockerignore @@ -47,3 +47,4 @@ contracts/.git !etc/env/consensus_secrets.yaml !etc/env/consensus_config.yaml !rust-toolchain +!patches/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index ced1b4bf2..8a14a36a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,7 +81,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if 1.0.0", - "getrandom", "once_cell", "version_check", "zerocopy", @@ -203,7 +202,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0609c78bd572f4edc74310dfb63a01f5609d53fa8b4dd7c4d98aef3b3e8d72d1" dependencies = [ "proc-macro-hack", - "quote 1.0.33", + "quote 1.0.37", "syn 1.0.109", ] @@ -257,9 +256,9 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -268,9 +267,9 @@ version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -289,13 +288,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] -name = "atomic-write-file" -version = "0.1.2" +name = "attohttpc" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" +checksum = "0f77d243921b0979fbbd728dd2d5162e68ac8252976797c24eb5b3a6af9090dc" dependencies = [ - "nix", - "rand 0.8.5", + "http 0.2.9", + "log", + "native-tls", + "serde", + "serde_json", + "url", ] [[package]] @@ -315,6 +318,32 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "aws-creds" +version = "0.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390ad3b77f3e21e01a4a0355865853b681daf1988510b0b15e31c0c4ae7eb0f6" +dependencies = [ + "attohttpc", + "home", + "log", + "quick-xml", + "rust-ini", + "serde", + "thiserror", + "time", + "url", +] + +[[package]] +name = "aws-region" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42fed2b9fca70f2908268d057a607f2a906f47edbf856ea8587de9038d264e22" +dependencies = [ + "thiserror", +] + [[package]] name = "axum" version = "0.6.20" @@ -521,10 +550,12 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.3.1" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa" +checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" dependencies = [ + "autocfg", + "libm", "num-bigint 0.4.4", "num-integer", "num-traits", @@ -552,12 +583,12 @@ dependencies = [ "lazycell", "peeking_take_while", "prettyplease", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "regex", 
"rustc-hash", "shlex", - "syn 2.0.38", + "syn 2.0.89", ] [[package]] @@ -769,9 +800,9 @@ checksum = "bf4918709cc4dd777ad2b6303ed03cb37f3ca0ccede8c1b0d28ac6db8f4710e0" dependencies = [ "once_cell", "proc-macro-crate 2.0.0", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", "syn_derive", ] @@ -810,8 +841,8 @@ version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7ec4c6f261935ad534c0c22dbef2201b45918860eb1c574b972bd213a76af61" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -889,12 +920,13 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.83" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] @@ -956,9 +988,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", @@ -966,7 +998,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -1181,9 +1213,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" dependencies = [ "heck 0.4.1", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -1248,6 +1280,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bed69047ed42e52c7e38d6421eeb8ceefb4f2a2b52eed59137f7bad7908f6800" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils 0.8.16", +] + [[package]] name = "console" version = "0.15.7" @@ -1266,6 +1307,26 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak 2.0.2", +] + [[package]] name = "const_format" version = "0.2.32" @@ -1281,8 +1342,8 @@ version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7f6ff08fd20f4f299298a28e2dfa8a8ba1036e6cd2460ac1de7b425d76f2500" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "unicode-xid 0.2.4", ] @@ -1565,8 +1626,8 @@ version = 
"0.1.0" source = "git+https://github.com/matter-labs/era-boojum.git?branch=main#4bcb11f0610302110ae8109af01d5b652191b2f6" dependencies = [ "proc-macro-error", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -1611,9 +1672,9 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -1634,8 +1695,8 @@ checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" dependencies = [ "fnv", "ident_case", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "strsim 0.10.0", "syn 1.0.109", ] @@ -1647,7 +1708,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" dependencies = [ "darling_core", - "quote 1.0.33", + "quote 1.0.37", "syn 1.0.109", ] @@ -1658,7 +1719,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if 1.0.0", - "hashbrown 0.14.2", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -1711,8 +1772,8 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -1731,9 +1792,9 @@ version = "1.0.0-beta.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bba3e9872d7c58ce7ef0fcf1844fcc3e23ef2a58377b50df35dd98e42a5726e" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", "unicode-xid 0.2.4", ] @@ -1764,6 +1825,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "dlv-list" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f" +dependencies = [ + "const-random", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -1935,12 +2005,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.5" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2009,15 +2079,20 @@ dependencies = [ [[package]] name = "event-listener" -version = "2.5.3" +version = "5.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] [[package]] name = "fastrand" -version = "2.0.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" 
[[package]] name = "ff" @@ -2061,8 +2136,8 @@ dependencies = [ "num-bigint 0.4.4", "num-integer", "num-traits", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "serde", "syn 1.0.109", ] @@ -2280,9 +2355,9 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -2625,9 +2700,15 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.7", "allocator-api2", @@ -2635,11 +2716,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" dependencies = [ - "hashbrown 0.14.2", + "hashbrown 0.14.5", ] [[package]] @@ -2656,9 +2737,6 @@ name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -dependencies = [ - "unicode-segmentation", -] [[package]] name = "heck" @@ -3009,8 +3087,8 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -3031,7 +3109,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.2", + "hashbrown 0.14.5", ] [[package]] @@ -3118,9 +3196,9 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jobserver" -version = "0.1.27" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -3239,9 +3317,9 @@ checksum = "7895f186d5921065d96e16bd795e5ca89ac8356ec423fafc6e3d7cf8ec11aee4" dependencies = [ "heck 0.5.0", "proc-macro-crate 3.1.0", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -3398,9 +3476,9 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.149" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libloading" @@ -3436,9 +3514,9 @@ dependencies = [ [[package]] 
name = "libsqlite3-sys" -version = "0.27.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ "cc", "pkg-config", @@ -3477,16 +3555,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba125974b109d512fccbc6c0244e7580143e460895dfd6ea7f8bbb692fd94396" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] name = "linux-raw-sys" -version = "0.4.10" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "loadnext" @@ -3553,10 +3631,10 @@ checksum = "dc487311295e0002e452025d6b580b77bb17286de87b57138f3b5db711cded68" dependencies = [ "beef", "fnv", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "regex-syntax 0.6.29", - "syn 2.0.38", + "syn 2.0.89", ] [[package]] @@ -3614,6 +3692,17 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maybe-async" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cf92c10c7e361d6b99666ec1c6f9805b0bea2c3bd8c78dc6fe98ac5bd78db11" +dependencies = [ + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", +] + [[package]] name = "maybe-uninit" version = "2.0.0" @@ -3630,6 +3719,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.6.4" @@ -3687,9 +3782,9 @@ version = "5.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -3723,6 +3818,15 @@ dependencies = [ "triomphe", ] +[[package]] +name = "minidom" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f45614075738ce1b77a1768912a60c0227525971b03e09122a05b8a34a2a6278" +dependencies = [ + "rxml", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3985,10 +4089,10 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96667db765a921f7b295ffee8b60472b686a51d4f21c2ee4ffdb94c7013b65a6" dependencies = [ - "proc-macro-crate 1.3.1", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro-crate 1.1.3", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -3998,9 +4102,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "681030a937600a36906c185595136d26abfebb4aa9c65701cefcaf8578bb982b" dependencies = [ "proc-macro-crate 3.1.0", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4032,9 +4136,9 @@ checksum = 
"624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.57" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bac25ee399abb46215765b1cb35bc0212377e58a061560d8b29b024fd0430e7c" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if 1.0.0", @@ -4051,9 +4155,9 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4064,9 +4168,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.93" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db4d56a4c0478783083cfafcc42493dd4a981d41669da64b4572a2a089b51b1d" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -4196,6 +4300,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-multimap" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ed8acf08e98e744e5384c8bc63ceb0364e68a6854187221c18df61c4797690e" +dependencies = [ + "dlv-list", + "hashbrown 0.13.2", +] + [[package]] name = "os_info" version = "3.7.0" @@ -4288,12 +4402,18 @@ version = "3.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "312270ee71e1cd70289dacf597cab7b207aa107d2f28191c2ae45b2ece18a260" dependencies = [ - "proc-macro-crate 1.3.1", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro-crate 1.1.3", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.1" @@ -4312,7 +4432,7 @@ checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if 1.0.0", "libc", - "redox_syscall 0.4.1", + "redox_syscall", "smallvec", "windows-targets 0.48.5", ] @@ -4383,9 +4503,9 @@ checksum = "2a31940305ffc96863a735bef7c7994a00b325a7138fdbc5bda0f1a0476d3275" dependencies = [ "pest", "pest_meta", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4424,9 +4544,9 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4557,8 +4677,8 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ - "proc-macro2 1.0.69", - "syn 2.0.38", + "proc-macro2 1.0.92", + "syn 2.0.89", ] [[package]] @@ -4585,12 +4705,12 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "1.3.1" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +checksum = 
"e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" dependencies = [ - "once_cell", - "toml_edit 0.19.15", + "thiserror", + "toml", ] [[package]] @@ -4618,8 +4738,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" dependencies = [ "proc-macro-error-attr", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", "version_check", ] @@ -4630,8 +4750,8 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "version_check", ] @@ -4652,9 +4772,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -4677,9 +4797,9 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4719,7 +4839,7 @@ dependencies = [ "prost 0.12.1", "prost-types", "regex", - "syn 2.0.38", + "syn 2.0.89", "tempfile", "which", ] @@ -4732,8 +4852,8 @@ checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" dependencies = [ "anyhow", "itertools 0.10.5", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -4745,9 +4865,9 @@ checksum = "265baba7fabd416cf5078179f7d2cbeca4ce7a9041111900675ea7c4cb8a4c32" dependencies = [ "anyhow", "itertools 0.10.5", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -4817,8 +4937,8 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -4858,6 +4978,16 @@ dependencies = [ "byteorder", ] +[[package]] +name = "quick-xml" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eff6510e86862b57b210fd8cbe8ed3f0d7d600b9c2863cd4549a2e033c66e956" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "quote" version = "0.6.13" @@ -4869,11 +4999,11 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.33" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ - "proc-macro2 1.0.69", + "proc-macro2 1.0.92", ] [[package]] @@ -4987,15 +5117,6 @@ dependencies = [ "rand_core 0.3.1", ] -[[package]] -name = "redox_syscall" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" -dependencies = [ - "bitflags 1.3.2", -] - [[package]] name = "redox_syscall" version = 
"0.4.1" @@ -5237,8 +5358,8 @@ version = "0.7.43" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5c462a1328c8e67e4d6dbad1eb0355dd43e8ab432c6e227a43657f16ade5033" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -5288,6 +5409,53 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rust-ini" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e2a3bcec1f113553ef1c88aae6c020a369d03d55b58de9869a0908930385091" +dependencies = [ + "cfg-if 1.0.0", + "ordered-multimap", +] + +[[package]] +name = "rust-s3" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6679da8efaf4c6f0c161de0961dfe95fb6e9049c398d6fbdada2639f053aedb" +dependencies = [ + "async-trait", + "aws-creds", + "aws-region", + "base64 0.21.5", + "bytes", + "cfg-if 1.0.0", + "futures 0.3.28", + "hex", + "hmac", + "http 0.2.9", + "hyper 0.14.29", + "hyper-tls 0.5.0", + "log", + "maybe-async", + "md5", + "minidom", + "native-tls", + "percent-encoding", + "quick-xml", + "serde", + "serde_derive", + "serde_json", + "sha2 0.10.8", + "thiserror", + "time", + "tokio", + "tokio-native-tls", + "tokio-stream", + "url", +] + [[package]] name = "rust_decimal" version = "1.33.1" @@ -5333,15 +5501,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.20" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ce50cb2e16c2903e30d1cbccfd8387a74b9d4c938b6a4c5ec6cc7556f7a8a0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -5446,6 +5614,23 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +[[package]] +name = "rxml" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a98f186c7a2f3abbffb802984b7f1dfd65dac8be1aafdaabbca4137f53f0dff7" +dependencies = [ + "bytes", + "rxml_validation", + "smartstring", +] + +[[package]] +name = "rxml_validation" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22a197350ece202f19a166d1ad6d9d6de145e1d2a8ef47db299abe164dbd7530" + [[package]] name = "ryu" version = "1.0.15" @@ -5715,9 +5900,9 @@ version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -5772,8 +5957,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e182d6ec6f05393cc0e5ed1bf81ad6db3a8feedf8ee515ecdd369809bcce8082" dependencies = [ "darling", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -5977,6 +6162,17 @@ dependencies = [ "serde", ] +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg", + "static_assertions", + "version_check", +] + [[package]] name = "snapshots_creator" version = "0.1.0" @@ -6091,9 +6287,9 @@ dependencies = [ 
[[package]] name = "sqlx" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" +checksum = "fcfa89bea9500db4a0d038513d7a060566bfc51d46d1c014847049a45cce85e8" dependencies = [ "sqlx-core", "sqlx-macros", @@ -6104,11 +6300,10 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" +checksum = "d06e2f2bd861719b1f3f0c7dbe1d80c30bf59e76cf019f07d9014ed7eefb8e08" dependencies = [ - "ahash 0.8.7", "atoi", "bigdecimal", "byteorder", @@ -6116,7 +6311,6 @@ dependencies = [ "chrono", "crc", "crossbeam-queue 0.3.8", - "dotenvy", "either", "event-listener", "futures-channel", @@ -6124,6 +6318,7 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", + "hashbrown 0.14.5", "hashlink", "hex", "indexmap 2.1.0", @@ -6149,31 +6344,30 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +checksum = "2f998a9defdbd48ed005a89362bd40dd2117502f15294f61c8d47034107dbbdc" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "sqlx-core", "sqlx-macros-core", - "syn 1.0.109", + "syn 2.0.89", ] [[package]] name = "sqlx-macros-core" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" +checksum = "3d100558134176a2629d46cec0c8891ba0be8910f7896abfdb75ef4ab6f4e7ce" dependencies = [ - "atomic-write-file", "dotenvy", "either", - "heck 0.4.1", + "heck 0.5.0", "hex", "once_cell", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "serde", "serde_json", "sha2 0.10.8", @@ -6181,7 +6375,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", - "syn 1.0.109", + "syn 2.0.89", "tempfile", "tokio", "url", @@ -6189,12 +6383,12 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" +checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12" dependencies = [ "atoi", - "base64 0.21.5", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -6234,12 +6428,10 @@ dependencies = [ [[package]] name = "sqlx-postgres" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +version = "0.8.1" dependencies = [ "atoi", - "base64 0.21.5", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -6266,7 +6458,6 @@ dependencies = [ "rust_decimal", "serde", "serde_json", - "sha1", "sha2 0.10.8", "smallvec", "sqlx-core", @@ -6278,9 +6469,9 @@ dependencies = [ [[package]] name = "sqlx-sqlite" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +checksum = "a75b419c3c1b1697833dd927bdc4c6545a620bc1bbafabd44e1efbe9afcd337e" dependencies = [ "atoi", "chrono", @@ -6294,10 +6485,10 @@ dependencies = [ 
"log", "percent-encoding", "serde", + "serde_urlencoded", "sqlx-core", "tracing", "url", - "urlencoding", ] [[package]] @@ -6354,8 +6545,8 @@ checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" dependencies = [ "heck 0.3.3", "proc-macro-error", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "syn 1.0.109", ] @@ -6375,8 +6566,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" dependencies = [ "heck 0.4.1", - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "rustversion", "syn 1.0.109", ] @@ -6404,19 +6595,19 @@ version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.38" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", + "proc-macro2 1.0.92", + "quote 1.0.37", "unicode-ident", ] @@ -6427,9 +6618,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1329189c02ff984e9736652b1631330da25eaa6bc639089ed4915d25446cbe7b" dependencies = [ "proc-macro-error", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -6494,15 +6685,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.8.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if 1.0.0", "fastrand", - "redox_syscall 0.3.5", + "once_cell", "rustix", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -6529,9 +6720,9 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2cfbe7811249c4c914b06141b8ac0f2cee2733fb883d05eb19668a45fc60c3d5" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -6550,9 +6741,9 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8f546451eaa38373f549093fe9fd05e7d2bade739e2ddf834b9968621d60107" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -6585,9 +6776,9 @@ version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -6749,9 +6940,9 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 
1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -6814,21 +7005,19 @@ dependencies = [ ] [[package]] -name = "toml_datetime" -version = "0.6.6" +name = "toml" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] [[package]] -name = "toml_edit" -version = "0.19.15" +name = "toml_datetime" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - "indexmap 2.1.0", - "toml_datetime", - "winnow", -] +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" [[package]] name = "toml_edit" @@ -6947,9 +7136,9 @@ version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -7157,7 +7346,7 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ad948c1cb799b1a70f836077721a92a35ac177d4daddf4c20a633786d4cf618" dependencies = [ - "quote 1.0.33", + "quote 1.0.37", "syn 1.0.109", ] @@ -7286,9 +7475,9 @@ name = "vise-macros" version = "0.1.0" source = "git+https://github.com/matter-labs/vise.git?rev=a5bb80c9ce7168663114ee30e794d6dc32159ee4#a5bb80c9ce7168663114ee30e794d6dc32159ee4" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -7359,9 +7548,9 @@ dependencies = [ "bumpalo", "log", "once_cell", - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", "wasm-bindgen-shared", ] @@ -7383,7 +7572,7 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" dependencies = [ - "quote 1.0.33", + "quote 1.0.37", "wasm-bindgen-macro-support", ] @@ -7393,9 +7582,9 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -7456,7 +7645,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e" dependencies = [ - "redox_syscall 0.4.1", + "redox_syscall", "wasite", ] @@ -7518,6 +7707,24 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -7548,6 +7755,22 @@ 
dependencies = [ "windows_x86_64_msvc 0.48.5", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -7560,6 +7783,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -7572,6 +7801,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -7584,6 +7819,18 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -7596,6 +7843,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -7608,6 +7861,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -7620,6 +7879,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -7632,6 +7897,12 
@@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.5.17" @@ -7700,9 +7971,9 @@ version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -7720,9 +7991,9 @@ version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -8538,7 +8809,11 @@ dependencies = [ "assert_matches", "async-trait", "chrono", + "hex", "once_cell", + "reqwest 0.11.22", + "rust-s3", + "serde_json", "test-casing", "thiserror", "tokio", @@ -8549,6 +8824,7 @@ dependencies = [ "zksync_dal", "zksync_eth_client", "zksync_l1_contract_interface", + "zksync_mini_merkle_tree", "zksync_node_fee_model", "zksync_node_test_utils", "zksync_object_store", @@ -9013,9 +9289,9 @@ dependencies = [ name = "zksync_node_framework_derive" version = "0.1.0" dependencies = [ - "proc-macro2 1.0.69", - "quote 1.0.33", - "syn 2.0.38", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] @@ -9163,12 +9439,12 @@ dependencies = [ "anyhow", "heck 0.5.0", "prettyplease", - "proc-macro2 1.0.69", + "proc-macro2 1.0.92", "prost-build", "prost-reflect", "protox", - "quote 1.0.33", - "syn 2.0.38", + "quote 1.0.37", + "syn 2.0.89", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 432f0c031..1a2b8fec1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,12 +99,13 @@ categories = ["cryptography"] [workspace.dependencies] # "External" dependencies +openssl = "=0.10.66" anyhow = "1" assert_matches = "1.5" async-trait = "0.1" axum = "0.7.5" backon = "0.4.4" -bigdecimal = "0.3.0" +bigdecimal = "=0.4.5" bincode = "1" blake2 = "0.10" chrono = "0.4" @@ -159,7 +160,7 @@ serde_with = "1" serde_yaml = "0.9" sha2 = "0.10.8" sha3 = "0.10.8" -sqlx = "0.7.3" +sqlx = "=0.8.1" static_assertions = "1.1" structopt = "0.3.20" strum = "0.24" @@ -279,3 +280,6 @@ zksync_contract_verification_server = { path = "core/node/contract_verification_ zksync_node_api_server = { path = "core/node/api_server" } zksync_tee_verifier_input_producer = { path = "core/node/tee_verifier_input_producer" } zksync_base_token_adjuster = {path = "core/node/base_token_adjuster"} + +[patch.crates-io] +sqlx-postgres = { path = "./patches/sqlx-postgres" } diff --git a/README.md b/README.md index 013d932aa..0d312ec57 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,424 @@ -# ZKsync Era: A ZK Rollup For Scaling Ethereum +# zkthunder User Document + +## Project Overview + +The zkthunder project includes three main directories: -[![Logo](eraLogo.png)](https://zksync.io/) +- `./local-setup`: Containing the docker-compose file that organizes the entire project and other necessary configuration files (e.g., explorer json) for blockchain. 
+- `./local-setup-test`: Some test scripts and contracts for developers to deploy and call contracts on the blockchain.
+- `./zkthunder`: An implementation of a zero-knowledge-proof-based Mintlayer blockchain service.
+
+The following are the core components of the zkthunder project:
+
+- **4EVERLAND**: A holistic storage network compatible with IPFS. We use it as an IPFS-like storage system to save all the blockchain batch information.
+- **Mintlayer node and RPC wallet**: A Mintlayer node and a wallet should be deployed locally, since the zkthunder server interacts with them.
+- **zkthunder Docker Images**: The zkthunder server and the other necessary services (explorer, reth node, etc.) run in a docker-compose cluster.
+
+## Dependencies
+
+This is a shorter version of the setup guide, meant to make subsequent initializations easier. If you are initializing the workspace for the first time, it is recommended that you read the whole guide below, as it provides more context and tips.
+If you run on a 'clean' Ubuntu on GCP:
+
+```sh
+# Rust
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+# NVM
+curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.5/install.sh | bash
+# All necessary stuff
+sudo apt-get update
+sudo apt-get install build-essential pkg-config cmake clang lldb lld libssl-dev postgresql apt-transport-https ca-certificates curl software-properties-common
+# Install docker
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
+sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu focal stable"
+sudo apt install docker-ce
+sudo usermod -aG docker ${USER}
+# Stop default postgres (as we'll use the docker one)
+sudo systemctl stop postgresql
+sudo systemctl disable postgresql
+# Start docker.
+sudo systemctl start docker
+# You might need to re-connect (due to usermod change).
+# Node & yarn
+nvm install 20
+# Important: there will be a note in the output to load
+# new paths in your local session, either run it or reload the terminal.
+npm install -g yarn
+yarn set version 1.22.19
+# For running unit tests
+cargo install cargo-nextest
+# SQL tools
+cargo install sqlx-cli --version 0.8.0
+```
+
+## Build
+
+### Environment and Initialization
+
+First, set the environment variables in the zkthunder directory. In a terminal, run:
+
+```sh
+cd zkthunder
-ZKsync Era is a layer 2 rollup that uses zero-knowledge proofs to scale Ethereum without compromising on security or
-decentralization. Since it's EVM compatible (Solidity/Vyper), 99% of Ethereum projects can redeploy without refactoring
-or re-auditing a single line of code. ZKsync Era also uses an LLVM-based compiler that will eventually let developers
-write smart contracts in C++, Rust and other popular languages.
+export ZKSYNC_HOME=`pwd`
-## Knowledge Index
+export PATH=$ZKSYNC_HOME/bin:$PATH
+```
-The following questions will be answered by the following resources:
+
+Then, use the built-in 'zk-tools' to initialize the project. In the same terminal, run:
+
+```sh
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk
-| Question                                                 | Resource                                       |
-| -------------------------------------------------------- | ---------------------------------------------- |
-| What do I need to develop the project locally?           | [development.md](docs/guides/development.md)   |
-| How can I set up my dev environment?                     | [setup-dev.md](docs/guides/setup-dev.md)       |
-| How can I run the project?                               | [launch.md](docs/guides/launch.md)             |
-| What is the logical project structure and architecture?  | [architecture.md](docs/guides/architecture.md) |
-| Where can I find protocol specs?                         | [specs.md](docs/specs/README.md)               |
-| Where can I find developer docs?                         | [docs](https://docs.zksync.io)                 |
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk init
+```
-## Policies
+
+After doing this, you can also use the following commands to start or stop the existing docker containers:
-- [Security policy](SECURITY.md)
-- [Contribution policy](CONTRIBUTING.md)
+
+```sh
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk up
-## License
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk down
+```
-ZKsync Era is distributed under the terms of either
+
+### Build Images
+
+Now you can build docker images over an initialized zkthunder project:
+
+```sh
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk docker build server-v2 --custom-tag "zkthunder"
-- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or <https://www.apache.org/licenses/LICENSE-2.0>)
-- MIT license ([LICENSE-MIT](LICENSE-MIT) or <https://opensource.org/licenses/MIT>)
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk docker build local-node --custom-tag "zkthunder"
+```
-at your option.
+
+The built images are used by the docker-compose cluster. Make sure you build the server-v2 image first; otherwise the local-node build will fail.
+
+## Deploy
+
+### Mintlayer Node Deployment
+
+To run the zkthunder project, you need a Mintlayer node and an RPC wallet running locally. For example, if you have an official build of mintlayer-core, run the following commands in the mintlayer-core directory:
+
+```sh
+# run a node daemon
+cargo run --release --bin node-daemon -- testnet 2>&1 | tee ../mintlayer.log
+# run an RPC wallet daemon, in another terminal
+cargo run --release --bin wallet-rpc-daemon -- testnet --rpc-no-authentication 2>&1 | tee ../wallet-cli.log
+```
+
+Then use a Python script (or any other method you like) to open the wallet; you will of course need a rich wallet address to send transactions:
-at your option.
+
+```python
+import requests
+import json
-## Official Links
+
+rpc_url = 'http://127.0.0.1:13034'
+headers = {'content-type': 'application/json'}
-- [Website](https://zksync.io/)
-- [GitHub](https://github.com/matter-labs)
-- [ZK Credo](https://github.com/zksync/credo)
-- [Twitter](https://twitter.com/zksync)
-- [Twitter for Developers](https://twitter.com/zkSyncDevs)
-- [Discord](https://join.zksync.dev/)
-- [Mirror](https://zksync.mirror.xyz/)
-- [Youtube](https://www.youtube.com/@zkSync-era)
+
+payload = {
+    "method": "wallet_open",
+    "params": {
+        "path": "path/to/wallet.dat",
+    },
+    "jsonrpc": "2.0",
+    "id": 1,
+}
+response = requests.post(rpc_url, data=json.dumps(payload), headers=headers)
+print(response.json())
+```
-## Disclaimer
+
+Note that rpc_url is the local port of the Mintlayer RPC wallet.
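+
+The same JSON-RPC 2.0 exchange can also be driven from Rust, which mirrors how the zkthunder server itself talks to the wallet later in this document (see the `address_deposit_data` call in the 4EVERLAND Storage section). The sketch below is illustrative only: the helper name `ml_wallet_rpc` is ours, `wallet_open` is the method shown above, and it assumes the reqwest (with its `json` feature), serde_json, tokio and anyhow crates that this workspace already uses:
+
+```rust
+use reqwest::Client;
+use serde_json::{json, Value};
+
+// Minimal JSON-RPC 2.0 helper for the local Mintlayer RPC wallet (sketch).
+async fn ml_wallet_rpc(method: &str, params: Value) -> anyhow::Result<Value> {
+    let payload = json!({
+        "method": method,
+        "params": params,
+        "jsonrpc": "2.0",
+        "id": 1,
+    });
+    let response = Client::new()
+        .post("http://127.0.0.1:13034") // the wallet RPC port used above
+        .header("Content-Type", "application/json")
+        .json(&payload)
+        .send()
+        .await?;
+    Ok(response.json().await?)
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    // Open the wallet once after the daemon starts; the path is a placeholder.
+    let result = ml_wallet_rpc("wallet_open", json!({ "path": "path/to/wallet.dat" })).await?;
+    println!("{result}");
+    Ok(())
+}
+```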
+
+### zkthunder Docker Deployment
+
+To deploy the zkthunder service, just run the scripts in the local-setup directory; make sure that no other related containers are running:
+
+```sh
+cd ../local-setup
-ZKsync Era has been through lots of testing and audits. Although it is live, it is still in alpha state and will go
-through more audits and bug bounty programs. We would love to hear our community's thoughts and suggestions about it! It
-is important to state that forking it now can potentially lead to missing important security updates, critical features,
-and performance improvements.
+
+sudo ./start.sh
+```
+
+The script will bootstrap a docker cluster, which contains a complete running zkthunder service.
If it works, you may see output in the terminal like this, which means the docker cluster is running normally:
+
+```sh
+...
+zkthunder-1| 2024-08-01T07:25:32.922492Z INFO loop_iteration{l1_block_numbers=L1BlockNumbers { safe: L1BlockNumber(847), finalized: L1BlockNumber(847), latest: L1BlockNumber(848) }}: zksync_eth_sender::eth_tx_manager: Loop iteration at block 848
+zkthunder-1| 2024-08-01T07:25:32.923338Z INFO loop_iteration{l1_block_numbers=L1BlockNumbers { safe: L1BlockNumber(847), finalized: L1BlockNumber(847), latest: L1BlockNumber(848) }}: zksync_eth_sender::eth_tx_manager: Sending tx 38 at block 848 with base_fee_per_gas 1, priority_fee_per_gas 1000000000, blob_fee_per_gas None
+...
+```
+
+If you want to run zkthunder in the background instead, modify the `./local-setup/start.sh` script by appending -d to the command:
+
+```sh
+# In ./start.sh
+# docker compose up
+docker compose up -d
+```
+
+To stop the zkthunder docker service, run:
+
+```sh
+cd ../local-setup
+
+sudo ./clear.sh
+```
+
+### zkthunder Test
+
+With a running zkthunder docker cluster and a local Mintlayer node (as well as an open wallet), you can test deploying and calling contracts using the provided scripts. But first, you need to install the dependencies:
+
+```sh
+cd ./local-setup-test
+# This command will install dependencies
+yarn
+```
+
+There are three example testing scripts and a contract in the directory.
+
+- local-setup-test/contracts
+
+1. **Greeter.sol**. A Solidity smart contract that does nothing but greet.
+
+- local-setup-test/scripts
+
+1. **run.ts**. A script that deploys a contract and calls it 50 times.
+
+2. **run-many-users.ts**. A script that, for a list of addresses (10 rich wallets), deploys a contract and calls it 10 times per address.
+
+- local-setup-test/test
+
+1. **main.test.ts**. A script that deploys a contract and calls it 10 times.
+
+To run the various tests, use the commands below:
+
+```sh
+# simply run main.test.ts
+yarn test
+# run run.ts with hardhat
+NODE_ENV=test npx hardhat run ./scripts/run.ts
+# run run-many-users.ts with multiple addresses
+sudo bash ./bandwidth.sh
+```
+
+The configuration of hardhat, including the endpoints for the local tests, is in the file ./local-setup-test/hardhat.config.ts.
+
+### Rich Wallets
+
+The tests always need some rich wallet addresses with large amounts of ETH on both L1 and L2. You can find them in `./local-setup/rich-wallets.json`.
+
+Also, during the initial bootstrapping of the system, several ERC-20 contracts are deployed locally. Note that large quantities of these ERC-20 tokens belong to the wallet `0x36615Cf349d7F6344891B1e7CA7C72883F5dc049` (the first one in the list of rich wallets). Right after bootstrapping the system, these ERC-20 funds are available only on L1.
+
+## Docker-compose Configuration
+
+Now let's take a deep dive into docker-compose.yaml to see how zkthunder works.
+
+This docker-compose file sets up the full zkthunder network, consisting of:
+
+- L1 (private reth) with explorer (blockscout)
+
+- a single postgres (with all the databases)
+
+- L2 zkthunder chain, together with its explorer
+
+- hyperexplorer to merge L1 and L2 together.
+
+For the port settings:
+
+- hyperexplorer:
+
+1. - http
+
+- L1 chain:
+
+1. 15045 - http
+
+- L1 explorer:
+
+1. - http
+
+- L2 chain (zkthunder):
+
+1. - http rpc
+2. - ws rpc
+
+- L2 explorer:
+
+1. - http
+2. 3020 - explorer api
+3. 15103 - explorer worker
+4. 15104 - explorer data-fetcher
+5. 15105 - explorer api metrics
+
+In this section, we focus on the services named proxy-relay and zkthunder; see their settings from docker-compose.yaml below:
+
+```yaml
+# zkthunder
+
+  proxy-relay:
+    image: alpine/socat:latest
+    network_mode: host
+    command: TCP-LISTEN:13034,fork,bind=host.docker.internal TCP-CONNECT:127.0.0.1:13034
+    extra_hosts:
+      - host.docker.internal:host-gateway
+
+  zkthunder:
+    stdin_open: true
+    tty: true
+    image: matterlabs/local-node:${INSTANCE_TYPE:-zkthunder}
+    healthcheck:
+      test: curl --fail http://localhost:3071/health || exit 1
+      interval: 10s
+      timeout: 5s
+      retries: 200
+      start_period: 30s
+    environment:
+      - DATABASE_PROVER_URL=postgresql://postgres:notsecurepassword@postgres:5432/prover_local
+      - DATABASE_URL=postgresql://postgres:notsecurepassword@postgres:5432/zksync_local
+      - ETH_CLIENT_WEB3_URL=http://reth:8545
+      - LEGACY_BRIDGE_TESTING=1
+      # - IPFS_API_URL=http://ipfs:5001
+      - ML_RPC_URL=http://host.docker.internal:13034 # change to mainnet if needed
+      - ML_BATCH_SIZE=10 # change if necessary
+      - 4EVERLAND_API_KEY=5F2R8SK2EQNSNCHSRWIK # only for test
+      - 4EVERLAND_SECRET_KEY=sCGfIdQZfis8YVCXnQP53SL8cPdRxyzjPLh1KYmF # only for test
+      - 4EVERLAND_BUCKET_NAME=zkthunder # only for test
+    ports:
+      - 15100:3050 # JSON RPC HTTP port
+      - 15101:3051 # JSON RPC WS port
+    depends_on:
+      - reth
+      - postgres
+      - proxy-relay
+    volumes:
+      - shared_config:/etc/env/target
+      - shared_tokens:/etc/tokens
+    extra_hosts:
+      - host.docker.internal:host-gateway
+```
+
+The proxy-relay service forwards requests from inside docker to the local address on the host machine, so that services inside docker can access the Mintlayer network.
+
+In zkthunder's environment settings (a sketch of how these variables can be read follows the list):
+
+- **ML_RPC_URL** is the RPC wallet endpoint of Mintlayer.
+
+- **ML_BATCH_SIZE** controls how frequently data is sent to Mintlayer.
+
+- **4EVERLAND_API_KEY, 4EVERLAND_SECRET_KEY, 4EVERLAND_BUCKET_NAME**: these three variables identify a specific bucket on 4everland; we upload the block information to it.
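+
+For illustration, here is a minimal sketch of how these settings can be read at startup. The `MintlayerConfig` struct and its defaults are our own illustration rather than the actual zkthunder config types; the `ML_BATCH_SIZE` parsing mirrors the real snippet shown in the next section:
+
+```rust
+use std::env;
+
+// Illustrative only: gathers the Mintlayer/4EVERLAND settings described above.
+struct MintlayerConfig {
+    ml_rpc_url: String,
+    ml_batch_size: usize,
+    bucket_name: String,
+}
+
+impl MintlayerConfig {
+    fn from_env() -> Self {
+        Self {
+            // Defaults follow the docker-compose values above.
+            ml_rpc_url: env::var("ML_RPC_URL")
+                .unwrap_or_else(|_| "http://host.docker.internal:13034".into()),
+            ml_batch_size: env::var("ML_BATCH_SIZE")
+                .ok()
+                .and_then(|v| v.parse().ok())
+                .unwrap_or(10),
+            bucket_name: env::var("4EVERLAND_BUCKET_NAME")
+                .unwrap_or_else(|_| "zkthunder".into()),
+        }
+    }
+}
+
+fn main() {
+    let config = MintlayerConfig::from_env();
+    println!("reporting to {} every {} blocks", config.ml_rpc_url, config.ml_batch_size);
+}
+```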
+
+In the next section, we provide a detailed explanation of how we handle data storage on 4everland.
+
+## 4EVERLAND Storage
+
+There are three types of L2 batch operations, named Commit, Prove and Execute. Each batch includes block metadata, the state root, system logs, ZK proofs, etc. We fetch the data of each batch and send it to a specific 4everland bucket:
+
+```rust
+// put this document to 4everland/ipfs
+let response_data = bucket
+    .put_object_stream(&mut contents, ipfs_doc_name.clone())
+    .await
+    .unwrap();
+tracing::info!(
+    "put {} to ipfs and get response code: {:?}",
+    ipfs_doc_name,
+    response_data.status_code()
+);
+```
+
+Note that three batch operations correspond to one block. Every time we add a batch's data to the 4everland bucket, the storage network responds with an IPFS hash value. We collect these values until the number of responses reaches the threshold of **BATCH_SIZE*3**.
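+
+Concretely, each Commit/Prove/Execute operation contributes one hash, so the queue fills after every **ML_BATCH_SIZE** blocks (with the default of 10, that is 30 hashes). The following self-contained sketch condenses this bookkeeping; `on_batch_hash` and `upload_to_bucket` are our illustrative names, and the stubbed upload stands in for the `put_object_stream` call shown in the surrounding snippets:
+
+```rust
+use std::io::Cursor;
+
+// Sketch: one IPFS hash per batch operation, flushed every batch_size * 3
+// hashes, i.e. once per batch_size blocks.
+struct IpfsHashAggregator {
+    ipfs_hash_queue: Vec<String>,
+    batch_size: usize, // ML_BATCH_SIZE
+}
+
+impl IpfsHashAggregator {
+    // Records one batch hash; returns the root document title on a flush.
+    fn on_batch_hash(&mut self, hash: String) -> Option<String> {
+        self.ipfs_hash_queue.push(hash);
+        if self.ipfs_hash_queue.len() < self.batch_size * 3 {
+            return None; // keep accumulating
+        }
+        // Name the root document after the first and last hash in the window,
+        // mirroring the `batch_{}_{}` title used in the real code below.
+        let title = format!(
+            "batch_{}_{}",
+            self.ipfs_hash_queue[0],
+            self.ipfs_hash_queue.last().unwrap()
+        );
+        let contents = serde_json::to_string(&self.ipfs_hash_queue).unwrap();
+        upload_to_bucket(&mut Cursor::new(contents), &title);
+        self.ipfs_hash_queue.clear();
+        Some(title)
+    }
+}
+
+// Placeholder for the 4everland upload; the real call is
+// `bucket.put_object_stream(&mut data, title).await`.
+fn upload_to_bucket(_data: &mut Cursor<String>, title: &str) {
+    println!("would upload {title} to the 4everland bucket");
+}
+
+fn main() {
+    let mut agg = IpfsHashAggregator { ipfs_hash_queue: Vec::new(), batch_size: 2 };
+    for i in 0..6 {
+        // 2 blocks x 3 operations = 6 hashes -> exactly one flush
+        if let Some(title) = agg.on_batch_hash(format!("Qm{i}")) {
+            println!("flushed as {title}");
+        }
+    }
+}
+```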
+
+Then we upload all of these hash values, as a single file, to 4EVERLAND storage:
+
+```rust
+// if block_number reaches the BATCH_SIZE, report the hashes to ipfs and then mintlayer
+// (the number of aggregated operations for mintlayer defaults to 10)
+let batch_size: usize = env::var("ML_BATCH_SIZE")
+    .ok()
+    .and_then(|v| v.parse().ok())
+    .unwrap_or(10);
+
+...
+```
+
+```rust
+let root_hash: Option<String> = if self.ipfs_hash_queue.len() == hash_queue_limit {
+    let title = format!(
+        "batch_{}_{}",
+        self.ipfs_hash_queue[0],
+        self.ipfs_hash_queue.last().unwrap()
+    );
+    let contents = self.ipfs_hash_queue.clone();
+    let mut data = Cursor::new(serde_json::to_string(&contents).unwrap());
+
+    // put this document to 4everland/ipfs
+    let response_data = bucket
+        .put_object_stream(&mut data, title.clone())
+        .await
+        .unwrap();
+    tracing::info!(
+        "put hashes {} to ipfs and get response code: {:?}",
+        title,
+        response_data.status_code()
+    );
+...
+}
+```
+
+As before, the 4EVERLAND IPFS network returns a hash value, which identifies the file storing all the IPFS hashes of the batch information. We save this "overall" root hash to the Mintlayer network using the `address_deposit_data` method:
+
+```rust
+if root_hash.is_some() {
+    // mintlayer
+    let mintlayer_rpc_url = env::var("ML_RPC_URL").unwrap();
+    let mintlayer_client = Client::new();
+    let headers = {
+        let mut headers = reqwest::header::HeaderMap::new();
+        headers.insert("Content-Type", "application/json".parse().unwrap());
+        headers
+    };
+
+    // add the digest to mintlayer
+    let payload = json!({
+        "method": "address_deposit_data",
+        "params": {
+            // the hash is encoded as an ASCII hex string
+            "data": hex::encode(root_hash.unwrap()),
+            "account": 0, // default to account 0
+            "options": {}
+        },
+        "jsonrpc": "2.0",
+        "id": 1,
+    });
+    let response = mintlayer_client
+        .post(&mintlayer_rpc_url)
+        .headers(headers)
+        .json(&payload)
+        .send()
+        .await
+        .unwrap();
+
+...
+}
+```
+
+## Development
+
+You can develop your own zkthunder service by modifying the zkthunder code. The following commands help you quickly run the service:
+
+```sh
+# enable zk tools
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk
+# init the project
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk init
+# start the docker containers
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk up
+# start the zkthunder server
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk server
+# stop the zkthunder containers
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk down
+# clean everything generated by zk init
+ZKSYNC_HOME=`pwd` PATH=$ZKSYNC_HOME/bin:$PATH zk clean --all
+```
diff --git a/README_zksync.md b/README_zksync.md
new file mode 100644
index 000000000..013d932aa
--- /dev/null
+++ b/README_zksync.md
@@ -0,0 +1,53 @@
+# ZKsync Era: A ZK Rollup For Scaling Ethereum
+
+[![Logo](eraLogo.png)](https://zksync.io/)
+
+ZKsync Era is a layer 2 rollup that uses zero-knowledge proofs to scale Ethereum without compromising on security or
+decentralization. Since it's EVM compatible (Solidity/Vyper), 99% of Ethereum projects can redeploy without refactoring
+or re-auditing a single line of code. ZKsync Era also uses an LLVM-based compiler that will eventually let developers
+write smart contracts in C++, Rust and other popular languages.
+ +## Knowledge Index + +The following questions will be answered by the following resources: + +| Question | Resource | +| ------------------------------------------------------- | ---------------------------------------------- | +| What do I need to develop the project locally? | [development.md](docs/guides/development.md) | +| How can I set up my dev environment? | [setup-dev.md](docs/guides/setup-dev.md) | +| How can I run the project? | [launch.md](docs/guides/launch.md) | +| What is the logical project structure and architecture? | [architecture.md](docs/guides/architecture.md) | +| Where can I find protocol specs? | [specs.md](docs/specs/README.md) | +| Where can I find developer docs? | [docs](https://docs.zksync.io) | + +## Policies + +- [Security policy](SECURITY.md) +- [Contribution policy](CONTRIBUTING.md) + +## License + +ZKsync Era is distributed under the terms of either + +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or ) +- MIT license ([LICENSE-MIT](LICENSE-MIT) or ) + +at your option. + +## Official Links + +- [Website](https://zksync.io/) +- [GitHub](https://github.com/matter-labs) +- [ZK Credo](https://github.com/zksync/credo) +- [Twitter](https://twitter.com/zksync) +- [Twitter for Developers](https://twitter.com/zkSyncDevs) +- [Discord](https://join.zksync.dev/) +- [Mirror](https://zksync.mirror.xyz/) +- [Youtube](https://www.youtube.com/@zkSync-era) + +## Disclaimer + +ZKsync Era has been through lots of testing and audits. Although it is live, it is still in alpha state and will go +through more audits and bug bounty programs. We would love to hear our community's thoughts and suggestions about it! It +is important to state that forking it now can potentially lead to missing important security updates, critical features, +and performance improvements. diff --git a/core/node/eth_sender/src/eth_tx_aggregator.rs b/core/node/eth_sender/src/eth_tx_aggregator.rs index dbfe3be94..4236ae280 100644 --- a/core/node/eth_sender/src/eth_tx_aggregator.rs +++ b/core/node/eth_sender/src/eth_tx_aggregator.rs @@ -393,7 +393,7 @@ impl EthTxAggregator { .await?; Self::report_eth_tx_saving(storage, &agg_op, &tx).await; - // zkmintlayer: A method `save_mintlayer_tx` to send the op to ipfs and mintlayer. + // zkthunder: A method `save_mintlayer_tx` to send the op to ipfs and mintlayer. 
self.save_mintlayer_tx(&agg_op).await; } diff --git a/deny.toml b/deny.toml index b50b165b7..c9c5a0959 100644 --- a/deny.toml +++ b/deny.toml @@ -6,9 +6,6 @@ vulnerability = "deny" unmaintained = "warn" yanked = "warn" notice = "warn" -ignore = [ - "RUSTSEC-2023-0018", -] [licenses] unlicensed = "deny" diff --git a/docker/local-node/Dockerfile b/docker/local-node/Dockerfile index 0a85d80eb..9425aeabf 100644 --- a/docker/local-node/Dockerfile +++ b/docker/local-node/Dockerfile @@ -1,7 +1,7 @@ # Image is always built from the server image to reuse the common parts # This image is expected to be built locally beforehand (implemented in the `zk` tool) # ARG BASE_VERSION=latest2.0 -ARG BASE_VERSION=zkmintlayer +ARG BASE_VERSION=zkthunder FROM matterlabs/server-v2:${BASE_VERSION} WORKDIR / diff --git a/local-setup-test/README.md b/local-setup-test/README.md index b82fb7ae6..690c31dcd 100644 --- a/local-setup-test/README.md +++ b/local-setup-test/README.md @@ -31,7 +31,7 @@ Make sure to get the correct bridgehub address (in this example: 0x35A3783781DE0 reth: only "reth:v0.2.0-beta.2" instance-type: -- zkmintlayer: can deploy contract and transfer +- zkthunder: can deploy contract and transfer - latest2.0: can transfer - hyperlocal: can deploy contract and transfer diff --git a/local-setup/docker-compose-dev.yml b/local-setup/docker-compose-dev.yml index a737ed136..ce3cb5829 100644 --- a/local-setup/docker-compose-dev.yml +++ b/local-setup/docker-compose-dev.yml @@ -1,13 +1,13 @@ # - L1 (reth) with explorer (blockscout) # - a single postgres (with all the databases) -# - L2 zkmintlayer chain, together with its explorer +# - L2 zkthunder chain, together with its explorer # Ports: # - l1 explorer: http://localhost:25001 (also using 25001, 25002, 25003) # - L1 chain (reth): # - 25045 - rpc -# - L2 chain (zkmintlayer): +# - L2 chain (zkthunder): # - 25100 - http # - 25101 - ws @@ -52,10 +52,10 @@ services: POSTGRES_PORT: 5432 restart: unless-stopped - zkmintlayer: + zkthunder: stdin_open: true tty: true - image: matterlabs/local-node:${INSTANCE_TYPE:-zkmintlayer-dev} + image: matterlabs/local-node:${INSTANCE_TYPE:-zkthunder-dev} healthcheck: test: curl --fail http://localhost:3071/health || exit 1 interval: 10s @@ -72,7 +72,7 @@ services: - ML_BATCH_SIZE=10 # change if necessary - 4EVERLAND_API_KEY=5F2R8SK2EQNSNCHSRWIK # only for test - 4EVERLAND_SECRET_KEY=sCGfIdQZfis8YVCXnQP53SL8cPdRxyzjPLh1KYmF # only for test - - 4EVERLAND_BUCKET_NAME=zkmintlayer # only for test + - 4EVERLAND_BUCKET_NAME=zkthunder # only for test ports: - 127.0.0.1:25100:3050 # JSON RPC HTTP port - 127.0.0.1:25101:3051 # JSON RPC WS port diff --git a/local-setup/docker-compose.yml b/local-setup/docker-compose.yml index b8a3a74e4..cee129a4f 100644 --- a/local-setup/docker-compose.yml +++ b/local-setup/docker-compose.yml @@ -1,9 +1,9 @@ -# This docker compose is setting up the full ZKMintlayer network, consisting of: +# This docker compose is setting up the full zkthunder network, consisting of: # # - L1 (private reth) with explorer (blockscout) # - a single postgres (with all the databases) # - a ipfs node -# - L2 zkmintlayer chain, together with its explorer +# - L2 zkthunder chain, together with its explorer # - hyperexplorer to merge L1, L2 all together. 
# Ports (if a port is written in the form http://localhost:PORT, it means that it can be accessed from the other machine): @@ -17,7 +17,7 @@ # - http://localhost:15002 - # - http://localhost:15003 - -# - L2 chain (zkmintlayer): +# - L2 chain (zkthunder): # - http://localhost:15100 - http rpc # - http://localhost:15101 - ws rpc # - L2 explorer: @@ -30,7 +30,7 @@ # Database is on 15432 # pgAdmin to manage PostgreSQL DB is on 15430 -# Besides, mintlayer rpc is on 13034/3034, change this in the zkmintlayer service if needed. +# Besides, mintlayer rpc is on 13034/3034, change this in the zkthunder service if needed. services: reth: @@ -90,7 +90,7 @@ services: POSTGRES_PORT: 5432 restart: unless-stopped - # zkmintlayer + # zkthunder proxy-relay: image: alpine/socat:latest network_mode: host @@ -98,10 +98,10 @@ services: extra_hosts: - host.docker.internal:host-gateway - zkmintlayer: + zkthunder: stdin_open: true tty: true - image: matterlabs/local-node:${INSTANCE_TYPE:-zkmintlayer} + image: matterlabs/local-node:${INSTANCE_TYPE:-zkthunder} healthcheck: test: curl --fail http://localhost:3071/health || exit 1 interval: 10s @@ -118,7 +118,7 @@ services: - ML_BATCH_SIZE=10 # change if necessary - 4EVERLAND_API_KEY=5F2R8SK2EQNSNCHSRWIK # only for test - 4EVERLAND_SECRET_KEY=sCGfIdQZfis8YVCXnQP53SL8cPdRxyzjPLh1KYmF # only for test - - 4EVERLAND_BUCKET_NAME=zkmintlayer # only for test + - 4EVERLAND_BUCKET_NAME=zkthunder # only for test ports: - 15100:3050 # JSON RPC HTTP port - 15101:3051 # JSON RPC WS port @@ -139,7 +139,7 @@ services: environment: - PORT=3040 - LOG_LEVEL=verbose - - BLOCKCHAIN_RPC_URL=http://zkmintlayer:3050 + - BLOCKCHAIN_RPC_URL=http://zkthunder:3050 ports: - 127.0.0.1:15104:3040 restart: unless-stopped @@ -155,7 +155,7 @@ services: - DATABASE_USER=postgres - DATABASE_PASSWORD=notsecurepassword - DATABASE_NAME=block-explorer - - BLOCKCHAIN_RPC_URL=http://zkmintlayer:3050 + - BLOCKCHAIN_RPC_URL=http://zkthunder:3050 - DATA_FETCHER_URL=http://data-fetcher-main:3040 - BATCHES_PROCESSING_POLLING_INTERVAL=1000 ports: @@ -264,7 +264,7 @@ services: hyperexplorer: depends_on: - zkmintlayer: + zkthunder: condition: service_healthy image: ghcr.io/mm-zk/zksync_tools:latest ports: diff --git a/local-setup/hyperexplorer.json b/local-setup/hyperexplorer.json index 0c955a96c..bc89a7d97 100644 --- a/local-setup/hyperexplorer.json +++ b/local-setup/hyperexplorer.json @@ -8,9 +8,9 @@ "shared_bridges": { "kl_exp": { "chains": { - "zkmintlayer": { + "zkthunder": { "chain_id": "0x10e", - "l2_url": "http://zkmintlayer:3050", + "l2_url": "http://zkthunder:3050", "explorer": "http://localhost:15005/?network=local", "type": "rollup" } diff --git a/local-setup/start-dev.sh b/local-setup/start-dev.sh index ac865efa5..01b847c08 100644 --- a/local-setup/start-dev.sh +++ b/local-setup/start-dev.sh @@ -6,16 +6,16 @@ # see https://hub.docker.com/r/matterlabs/local-node/tags for full list. # latest2.0 - is the 'main' one. -INSTANCE_TYPE=${1:-zkmintlayer-dev} +INSTANCE_TYPE=${1:-zkthunder-dev} export INSTANCE_TYPE=$INSTANCE_TYPE -echo "Starting ZKMintlayer Dev with instance type: $INSTANCE_TYPE" +echo "Starting zkthunder Dev with instance type: $INSTANCE_TYPE" docker compose -f docker-compose-dev.yml pull docker compose -f docker-compose-dev.yml up # docker compose -f docker-compose-dev.yml up -d check_all_services_healthy() { - service="zkmintlayer" + service="zkthunder" # service="zksync" (docker compose ps $service | grep "(healthy)") if [ $? 
-eq 0 ]; then @@ -39,7 +39,7 @@ DARKGRAY='\033[0;30m' ORANGE='\033[0;33m' echo -e "${GREEN}" -echo -e "SUCCESS, Your local ZKMintlayer Dev is now running! Find the information below for accessing each service." +echo -e "SUCCESS, Your local zkthunder Dev is now running! Find the information below for accessing each service." echo -e "┌──────────────────────────┬────────────────────────┬──────────────────────────────────────────────────┐" echo -e "│ Service │ URL │ Description │" echo -e "├──────────────────────────┼────────────────────────┼──────────────────────────────────────────────────┤" diff --git a/local-setup/start.sh b/local-setup/start.sh index 155ffee53..f078891ed 100644 --- a/local-setup/start.sh +++ b/local-setup/start.sh @@ -6,16 +6,16 @@ # see https://hub.docker.com/r/matterlabs/local-node/tags for full list. # latest2.0 - is the 'main' one. -INSTANCE_TYPE=${1:-zkmintlayer} +INSTANCE_TYPE=${1:-zkthunder} export INSTANCE_TYPE=$INSTANCE_TYPE -echo "Starting ZKMintlayer with instance type: $INSTANCE_TYPE" +echo "Starting zkthunder with instance type: $INSTANCE_TYPE" docker compose pull # docker compose up docker compose up check_all_services_healthy() { - service="zkmintlayer" + service="zkthunder" # service="zksync" (docker compose ps $service | grep "(healthy)") if [ $? -eq 0 ]; then @@ -39,7 +39,7 @@ DARKGRAY='\033[0;30m' ORANGE='\033[0;33m' echo -e "${GREEN}" -echo -e "SUCCESS, Your local ZKMintlayer is now running! Find the information below for accessing each service." +echo -e "SUCCESS, Your local zkthunder is now running! Find the information below for accessing each service." echo -e "┌──────────────────────────┬────────────────────────┬──────────────────────────────────────────────────┐" echo -e "│ Service │ URL │ Description │" echo -e "├──────────────────────────┼────────────────────────┼──────────────────────────────────────────────────┤" diff --git a/patches/sqlx-postgres/Cargo.toml b/patches/sqlx-postgres/Cargo.toml new file mode 100644 index 000000000..9c75f4565 --- /dev/null +++ b/patches/sqlx-postgres/Cargo.toml @@ -0,0 +1,262 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +name = "sqlx-postgres" +version = "0.8.1" +authors = [ + "Ryan Leckey ", + "Austin Bonander ", + "Chloe Ross ", + "Daniel Akhterov ", +] +build = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "PostgreSQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." 
+documentation = "https://docs.rs/sqlx" +readme = false +license = "MIT OR Apache-2.0" +repository = "https://github.com/launchbadge/sqlx" + +[lib] +name = "sqlx_postgres" +path = "src/lib.rs" + +[dependencies.atoi] +version = "2.0" + +[dependencies.base64] +version = "0.22.0" +features = ["std"] +default-features = false + +[dependencies.bigdecimal] +version = "0.4.0" +optional = true + +[dependencies.bit-vec] +version = "0.6.3" +optional = true + +[dependencies.bitflags] +version = "2" +default-features = false + +[dependencies.byteorder] +version = "1.4.3" +features = ["std"] +default-features = false + +[dependencies.chrono] +version = "0.4.34" +features = [ + "std", + "clock", +] +optional = true +default-features = false + +[dependencies.crc] +version = "3.0.0" + +[dependencies.dotenvy] +version = "0.15.0" +default-features = false + +[dependencies.futures-channel] +version = "0.3.19" +features = [ + "sink", + "alloc", + "std", +] +default-features = false + +[dependencies.futures-core] +version = "0.3.19" +default-features = false + +[dependencies.futures-io] +version = "0.3.24" + +[dependencies.futures-util] +version = "0.3.19" +features = [ + "alloc", + "sink", + "io", +] +default-features = false + +[dependencies.hex] +version = "0.4.3" + +[dependencies.hkdf] +version = "0.12.0" + +[dependencies.hmac] +version = "0.12.0" +features = ["reset"] +default-features = false + +[dependencies.home] +version = "0.5.5" + +[dependencies.ipnetwork] +version = "0.20.0" +optional = true + +[dependencies.itoa] +version = "1.0.1" + +[dependencies.log] +version = "0.4.18" + +[dependencies.mac_address] +version = "1.1.5" +optional = true + +[dependencies.md-5] +version = "0.10.0" +default-features = false + +[dependencies.memchr] +version = "2.4.1" +default-features = false + +[dependencies.num-bigint] +version = "0.4.3" +optional = true + +[dependencies.once_cell] +version = "1.9.0" + +[dependencies.rand] +version = "0.8.4" +features = [ + "std", + "std_rng", +] +default-features = false + +[dependencies.rust_decimal] +version = "1.26.1" +features = ["std"] +optional = true +default-features = false + +[dependencies.serde] +version = "1.0.144" +features = ["derive"] + +[dependencies.serde_json] +version = "1.0.85" +features = ["raw_value"] + +[dependencies.sha2] +version = "0.10.0" +default-features = false + +[dependencies.smallvec] +version = "1.7.0" +features = ["serde"] + +[dependencies.sqlx-core] +version = "=0.8.1" +features = ["json"] + +[dependencies.stringprep] +version = "0.1.2" + +[dependencies.thiserror] +version = "1.0.35" + +[dependencies.time] +version = "0.3.36" +features = [ + "formatting", + "parsing", + "macros", +] +optional = true + +[dependencies.tracing] +version = "0.1.37" +features = ["log"] + +[dependencies.uuid] +version = "1.1.2" +optional = true + +[dependencies.whoami] +version = "1.2.1" +default-features = false + +[dev-dependencies.sqlx] +version = "=0.8.1" +features = [ + "postgres", + "derive", +] +default-features = false + +[features] +any = ["sqlx-core/any"] +bigdecimal = [ + "dep:bigdecimal", + "dep:num-bigint", + "sqlx-core/bigdecimal", +] +bit-vec = [ + "dep:bit-vec", + "sqlx-core/bit-vec", +] +chrono = [ + "dep:chrono", + "sqlx-core/chrono", +] +ipnetwork = [ + "dep:ipnetwork", + "sqlx-core/ipnetwork", +] +json = ["sqlx-core/json"] +mac_address = [ + "dep:mac_address", + "sqlx-core/mac_address", +] +migrate = ["sqlx-core/migrate"] +offline = ["sqlx-core/offline"] +rust_decimal = [ + "dep:rust_decimal", + "rust_decimal/maths", + "sqlx-core/rust_decimal", +] 
+time = [ + "dep:time", + "sqlx-core/time", +] +uuid = [ + "dep:uuid", + "sqlx-core/uuid", +] + +[target.'cfg(target_os = "windows")'.dependencies.etcetera] +version = "0.8.0" + +[lints.clippy] +cast_possible_truncation = "deny" +cast_possible_wrap = "deny" +cast_sign_loss = "deny" +disallowed_methods = "deny" diff --git a/patches/sqlx-postgres/LICENSE-APACHE b/patches/sqlx-postgres/LICENSE-APACHE new file mode 100644 index 000000000..c79147e87 --- /dev/null +++ b/patches/sqlx-postgres/LICENSE-APACHE @@ -0,0 +1,201 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." 
+ +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. 
+ +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. 
We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright 2020 LaunchBadge, LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/patches/sqlx-postgres/LICENSE-MIT b/patches/sqlx-postgres/LICENSE-MIT new file mode 100644 index 000000000..861bf6085 --- /dev/null +++ b/patches/sqlx-postgres/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2020 LaunchBadge, LLC + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/patches/sqlx-postgres/src/advisory_lock.rs b/patches/sqlx-postgres/src/advisory_lock.rs new file mode 100644 index 000000000..d1aef176f --- /dev/null +++ b/patches/sqlx-postgres/src/advisory_lock.rs @@ -0,0 +1,421 @@ +use crate::error::Result; +use crate::Either; +use crate::PgConnection; +use hkdf::Hkdf; +use once_cell::sync::OnceCell; +use sha2::Sha256; +use std::ops::{Deref, DerefMut}; + +/// A mutex-like type utilizing [Postgres advisory locks]. +/// +/// Advisory locks are a mechanism provided by Postgres to have mutually exclusive or shared +/// locks tracked in the database with application-defined semantics, as opposed to the standard +/// row-level or table-level locks which may not fit all use-cases. +/// +/// This API provides a convenient wrapper for generating and storing the integer keys that +/// advisory locks use, as well as RAII guards for releasing advisory locks when they fall out +/// of scope. +/// +/// This API only handles session-scoped advisory locks (explicitly locked and unlocked, or +/// automatically released when a connection is closed). +/// +/// It is also possible to use transaction-scoped locks but those can be used by beginning a +/// transaction and calling the appropriate lock functions (e.g. `SELECT pg_advisory_xact_lock()`) +/// manually, and cannot be explicitly released, but are automatically released when a transaction +/// ends (is committed or rolled back). 
+/// +/// Session-level locks can be acquired either inside or outside a transaction and are not +/// tied to transaction semantics; a lock acquired inside a transaction is still held when that +/// transaction is committed or rolled back, until explicitly released or the connection is closed. +/// +/// Locks can be acquired in either shared or exclusive modes, which can be thought of as read locks +/// and write locks, respectively. Multiple shared locks are allowed for the same key, but a single +/// exclusive lock prevents any other lock being taken for a given key until it is released. +/// +/// [Postgres advisory locks]: https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS +#[derive(Debug, Clone)] +pub struct PgAdvisoryLock { + key: PgAdvisoryLockKey, + /// The query to execute to release this lock. + release_query: OnceCell, +} + +/// A key type natively used by Postgres advisory locks. +/// +/// Currently, Postgres advisory locks have two different key spaces: one keyed by a single +/// 64-bit integer, and one keyed by a pair of two 32-bit integers. The Postgres docs +/// specify that these key spaces "do not overlap": +/// +/// +/// +/// The documentation for the `pg_locks` system view explains further how advisory locks +/// are treated in Postgres: +/// +/// +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum PgAdvisoryLockKey { + /// The keyspace designated by a single 64-bit integer. + /// + /// When [PgAdvisoryLock] is constructed with [::new()][PgAdvisoryLock::new()], + /// this is the keyspace used. + BigInt(i64), + /// The keyspace designated by two 32-bit integers. + IntPair(i32, i32), +} + +/// A wrapper for `PgConnection` (or a similar type) that represents a held Postgres advisory lock. +/// +/// Can be acquired by [`PgAdvisoryLock::acquire()`] or [`PgAdvisoryLock::try_acquire()`]. +/// Released on-drop or via [`Self::release_now()`]. +/// +/// ### Note: Release-on-drop is not immediate! +/// On drop, this guard queues a `pg_advisory_unlock()` call on the connection which will be +/// flushed to the server the next time it is used, or when it is returned to +/// a [`PgPool`][crate::PgPool] in the case of +/// [`PoolConnection`][crate::pool::PoolConnection]. +/// +/// This means the lock is not actually released as soon as the guard is dropped. To ensure the +/// lock is eagerly released, you can call [`.release_now().await`][Self::release_now()]. +pub struct PgAdvisoryLockGuard<'lock, C: AsMut> { + lock: &'lock PgAdvisoryLock, + conn: Option, +} + +impl PgAdvisoryLock { + /// Construct a `PgAdvisoryLock` using the given string as a key. + /// + /// This is intended to make it easier to use an advisory lock by using a human-readable string + /// for a key as opposed to manually generating a unique integer key. The generated integer key + /// is guaranteed to be stable and in the single 64-bit integer keyspace + /// (see [`PgAdvisoryLockKey`] for details). + /// + /// This is done by applying the [Hash-based Key Derivation Function (HKDF; IETF RFC 5869)][hkdf] + /// to the bytes of the input string, but in a way that the calculated integer is unlikely + /// to collide with any similar implementations (although we don't currently know of any). + /// See the source of this method for details. 
+ /// + /// [hkdf]: https://datatracker.ietf.org/doc/html/rfc5869 + /// ### Example + /// ```rust + /// use sqlx::postgres::{PgAdvisoryLock, PgAdvisoryLockKey}; + /// + /// let lock = PgAdvisoryLock::new("my first Postgres advisory lock!"); + /// // Negative values are fine because of how Postgres treats advisory lock keys. + /// // See the documentation for the `pg_locks` system view for details. + /// assert_eq!(lock.key(), &PgAdvisoryLockKey::BigInt(-5560419505042474287)); + /// ``` + pub fn new(key_string: impl AsRef) -> Self { + let input_key_material = key_string.as_ref(); + + // HKDF was chosen because it is designed to concentrate the entropy in a variable-length + // input key and produce a higher quality but reduced-length output key with a + // well-specified and reproducible algorithm. + // + // Granted, the input key is usually meant to be pseudorandom and not human readable, + // but we're not trying to produce an unguessable value by any means; just one that's as + // unlikely to already be in use as possible, but still deterministic. + // + // SHA-256 was chosen as the hash function because it's already used in the Postgres driver, + // which should save on codegen and optimization. + + // We don't supply a salt as that is intended to be random, but we want a deterministic key. + let hkdf = Hkdf::::new(None, input_key_material.as_bytes()); + + let mut output_key_material = [0u8; 8]; + + // The first string is the "info" string of the HKDF which is intended to tie the output + // exclusively to SQLx. This should avoid collisions with implementations using a similar + // strategy. If you _want_ this to match some other implementation then you should get + // the calculated integer key from it and use that directly. + // + // Do *not* change this string as it will affect the output! + hkdf.expand( + b"SQLx (Rust) Postgres advisory lock", + &mut output_key_material, + ) + // `Hkdf::expand()` only returns an error if you ask for more than 255 times the digest size. + // This is specified by RFC 5869 but not elaborated upon: + // https://datatracker.ietf.org/doc/html/rfc5869#section-2.3 + // Since we're only asking for 8 bytes, this error shouldn't be returned. + .expect("BUG: `output_key_material` should be of acceptable length"); + + // For ease of use, this method assumes the user doesn't care which keyspace is used. + // + // It doesn't seem likely that someone would care about using the `(int, int)` keyspace + // specifically unless they already had keys to use, in which case they wouldn't + // care about this method. That's why we also provide `with_key()`. + // + // The choice of `from_le_bytes()` is mostly due to x86 being the most popular + // architecture for server software, so it should be a no-op there. + let key = PgAdvisoryLockKey::BigInt(i64::from_le_bytes(output_key_material)); + + tracing::trace!( + ?key, + key_string = ?input_key_material, + "generated key from key string", + ); + + Self::with_key(key) + } + + /// Construct a `PgAdvisoryLock` with a manually supplied key. + pub fn with_key(key: PgAdvisoryLockKey) -> Self { + Self { + key, + release_query: OnceCell::new(), + } + } + + /// Returns the current key. + pub fn key(&self) -> &PgAdvisoryLockKey { + &self.key + } + + // Why doesn't this use `Acquire`? Well, I tried it and got really useless errors + // about "cannot project lifetimes to parent scope". + // + // It has something to do with how lifetimes work on the `Acquire` trait, I couldn't + // be bothered to figure it out. 
Probably another issue with a lack of `async fn` in traits + // or lazy normalization. + + /// Acquires an exclusive lock using `pg_advisory_lock()`, waiting until the lock is acquired. + /// + /// For a version that returns immediately instead of waiting, see [`Self::try_acquire()`]. + /// + /// A connection-like type is required to execute the call. Allowed types include `PgConnection`, + /// `PoolConnection` and `Transaction`, as well as mutable references to + /// any of these. + /// + /// The returned guard queues a `pg_advisory_unlock()` call on the connection when dropped, + /// which will be executed the next time the connection is used, or when returned to a + /// [`PgPool`][crate::PgPool] in the case of `PoolConnection`. + /// + /// Postgres allows a single connection to acquire a given lock more than once without releasing + /// it first, so in that sense the lock is re-entrant. However, the number of unlock operations + /// must match the number of lock operations for the lock to actually be released. + /// + /// See [Postgres' documentation for the Advisory Lock Functions][advisory-funcs] for details. + /// + /// [advisory-funcs]: https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS + pub async fn acquire>( + &self, + mut conn: C, + ) -> Result> { + match &self.key { + PgAdvisoryLockKey::BigInt(key) => { + crate::query::query("SELECT pg_advisory_lock($1)") + .bind(key) + .execute(conn.as_mut()) + .await?; + } + PgAdvisoryLockKey::IntPair(key1, key2) => { + crate::query::query("SELECT pg_advisory_lock($1, $2)") + .bind(key1) + .bind(key2) + .execute(conn.as_mut()) + .await?; + } + } + + Ok(PgAdvisoryLockGuard::new(self, conn)) + } + + /// Acquires an exclusive lock using `pg_try_advisory_lock()`, returning immediately + /// if the lock could not be acquired. + /// + /// For a version that waits until the lock is acquired, see [`Self::acquire()`]. + /// + /// A connection-like type is required to execute the call. Allowed types include `PgConnection`, + /// `PoolConnection` and `Transaction`, as well as mutable references to + /// any of these. The connection is returned if the lock could not be acquired. + /// + /// The returned guard queues a `pg_advisory_unlock()` call on the connection when dropped, + /// which will be executed the next time the connection is used, or when returned to a + /// [`PgPool`][crate::PgPool] in the case of `PoolConnection`. + /// + /// Postgres allows a single connection to acquire a given lock more than once without releasing + /// it first, so in that sense the lock is re-entrant. However, the number of unlock operations + /// must match the number of lock operations for the lock to actually be released. + /// + /// See [Postgres' documentation for the Advisory Lock Functions][advisory-funcs] for details. + /// + /// [advisory-funcs]: https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS + pub async fn try_acquire>( + &self, + mut conn: C, + ) -> Result, C>> { + let locked: bool = match &self.key { + PgAdvisoryLockKey::BigInt(key) => { + crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1)") + .bind(key) + .fetch_one(conn.as_mut()) + .await? + } + PgAdvisoryLockKey::IntPair(key1, key2) => { + crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1, $2)") + .bind(key1) + .bind(key2) + .fetch_one(conn.as_mut()) + .await? 
+ } + }; + + if locked { + Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn))) + } else { + Ok(Either::Right(conn)) + } + } + + /// Execute `pg_advisory_unlock()` for this lock's key on the given connection. + /// + /// This is used by [`PgAdvisoryLockGuard::release_now()`] and is also provided for manually + /// releasing the lock from connections returned by [`PgAdvisoryLockGuard::leak()`]. + /// + /// An error should only be returned if there is something wrong with the connection, + /// in which case the lock will be automatically released by the connection closing anyway. + /// + /// The `boolean` value is that returned by `pg_advisory_lock()`. If it is `false`, it + /// indicates that the lock was not actually held by the given connection and that a warning + /// has been logged by the Postgres server. + pub async fn force_release>(&self, mut conn: C) -> Result<(C, bool)> { + let released: bool = match &self.key { + PgAdvisoryLockKey::BigInt(key) => { + crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1)") + .bind(key) + .fetch_one(conn.as_mut()) + .await? + } + PgAdvisoryLockKey::IntPair(key1, key2) => { + crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1, $2)") + .bind(key1) + .bind(key2) + .fetch_one(conn.as_mut()) + .await? + } + }; + + Ok((conn, released)) + } + + fn get_release_query(&self) -> &str { + self.release_query.get_or_init(|| match &self.key { + PgAdvisoryLockKey::BigInt(key) => format!("SELECT pg_advisory_unlock({key})"), + PgAdvisoryLockKey::IntPair(key1, key2) => { + format!("SELECT pg_advisory_unlock({key1}, {key2})") + } + }) + } +} + +impl PgAdvisoryLockKey { + /// Converts `Self::Bigint(bigint)` to `Some(bigint)` and all else to `None`. + pub fn as_bigint(&self) -> Option { + if let Self::BigInt(bigint) = self { + Some(*bigint) + } else { + None + } + } +} + +const NONE_ERR: &str = "BUG: PgAdvisoryLockGuard.conn taken"; + +impl<'lock, C: AsMut> PgAdvisoryLockGuard<'lock, C> { + fn new(lock: &'lock PgAdvisoryLock, conn: C) -> Self { + PgAdvisoryLockGuard { + lock, + conn: Some(conn), + } + } + + /// Immediately release the held advisory lock instead of when the connection is next used. + /// + /// An error should only be returned if there is something wrong with the connection, + /// in which case the lock will be automatically released by the connection closing anyway. + /// + /// If `pg_advisory_unlock()` returns `false`, a warning will be logged, both by SQLx as + /// well as the Postgres server. This would only happen if the lock was released without + /// using this guard, or the connection was swapped using [`std::mem::replace()`]. + pub async fn release_now(mut self) -> Result { + let (conn, released) = self + .lock + .force_release(self.conn.take().expect(NONE_ERR)) + .await?; + + if !released { + tracing::warn!( + lock = ?self.lock.key, + "PgAdvisoryLockGuard: advisory lock was not held by the contained connection", + ); + } + + Ok(conn) + } + + /// Cancel the release of the advisory lock, keeping it held until the connection is closed. + /// + /// To manually release the lock later, see [`PgAdvisoryLock::force_release()`]. 
+ pub fn leak(mut self) -> C { + self.conn.take().expect(NONE_ERR) + } +} + +impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLockGuard<'lock, C> { + type Target = PgConnection; + + fn deref(&self) -> &Self::Target { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +/// Mutable access to the underlying connection is provided so it can still be used like normal, +/// even allowing locks to be taken recursively. +/// +/// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] +/// is a logic error and will cause a warning to be logged by the PostgreSQL server when this +/// guard attempts to release the lock. +impl<'lock, C: AsMut + AsRef> DerefMut + for PgAdvisoryLockGuard<'lock, C> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +impl<'lock, C: AsMut + AsRef> AsRef + for PgAdvisoryLockGuard<'lock, C> +{ + fn as_ref(&self) -> &PgConnection { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +/// Mutable access to the underlying connection is provided so it can still be used like normal, +/// even allowing locks to be taken recursively. +/// +/// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] +/// is a logic error and will cause a warning to be logged by the PostgreSQL server when this +/// guard attempts to release the lock. +impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard<'lock, C> { + fn as_mut(&mut self) -> &mut PgConnection { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +/// Queues a `pg_advisory_unlock()` call on the wrapped connection which will be flushed +/// to the server the next time it is used, or when it is returned to [`PgPool`][crate::PgPool] +/// in the case of [`PoolConnection`][crate::pool::PoolConnection]. +impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + // Queue a simple query message to execute next time the connection is used. + // The `async fn` versions can safely use the prepared statement protocol, + // but this is the safest way to queue a query to execute on the next opportunity. 
+ conn.as_mut() + .queue_simple_query(self.lock.get_release_query()) + .expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol"); + } + } +} diff --git a/patches/sqlx-postgres/src/any.rs b/patches/sqlx-postgres/src/any.rs new file mode 100644 index 000000000..7eae4bcb7 --- /dev/null +++ b/patches/sqlx-postgres/src/any.rs @@ -0,0 +1,248 @@ +use crate::{ + Either, PgColumn, PgConnectOptions, PgConnection, PgQueryResult, PgRow, PgTransactionManager, + PgTypeInfo, Postgres, +}; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::future; + +pub use sqlx_core::any::*; + +use crate::type_info::PgType; +use sqlx_core::connection::Connection; +use sqlx_core::database::Database; +use sqlx_core::describe::Describe; +use sqlx_core::executor::Executor; +use sqlx_core::ext::ustr::UStr; +use sqlx_core::transaction::TransactionManager; + +sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); + +impl AnyConnectionBackend for PgConnection { + fn name(&self) -> &str { + ::NAME + } + + fn close(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close(*self) + } + + fn close_hard(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close_hard(*self) + } + + fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::ping(self) + } + + fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self) + } + + fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::commit(self) + } + + fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::rollback(self) + } + + fn start_rollback(&mut self) { + PgTransactionManager::start_rollback(self) + } + + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::flush(self) + } + + fn should_flush(&self) -> bool { + Connection::should_flush(self) + } + + #[cfg(feature = "migrate")] + fn as_migrate( + &mut self, + ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> { + Ok(self) + } + + fn fetch_many<'q>( + &'q mut self, + query: &'q str, + persistent: bool, + arguments: Option>, + ) -> BoxStream<'q, sqlx_core::Result>> { + let persistent = persistent && arguments.is_some(); + let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; + + Box::pin( + self.run(query, arguments, 0, persistent, None) + .try_flatten_stream() + .map( + move |res: sqlx_core::Result>| match res? { + Either::Left(result) => Ok(Either::Left(map_result(result))), + Either::Right(row) => Ok(Either::Right(AnyRow::try_from(&row)?)), + }, + ), + ) + } + + fn fetch_optional<'q>( + &'q mut self, + query: &'q str, + persistent: bool, + arguments: Option>, + ) -> BoxFuture<'q, sqlx_core::Result>> { + let persistent = persistent && arguments.is_some(); + let arguments = arguments + .as_ref() + .map(AnyArguments::convert_to) + .transpose() + .map_err(sqlx_core::Error::Encode); + + Box::pin(async move { + let arguments = arguments?; + let stream = self.run(query, arguments, 1, persistent, None).await?; + futures_util::pin_mut!(stream); + + if let Some(Either::Right(row)) = stream.try_next().await? 
{ + return Ok(Some(AnyRow::try_from(&row)?)); + } + + Ok(None) + }) + } + + fn prepare_with<'c, 'q: 'c>( + &'c mut self, + sql: &'q str, + _parameters: &[AnyTypeInfo], + ) -> BoxFuture<'c, sqlx_core::Result>> { + Box::pin(async move { + let statement = Executor::prepare_with(self, sql, &[]).await?; + AnyStatement::try_from_statement( + sql, + &statement, + statement.metadata.column_names.clone(), + ) + }) + } + + fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + Box::pin(async move { + let describe = Executor::describe(self, sql).await?; + + let columns = describe + .columns + .iter() + .map(AnyColumn::try_from) + .collect::, _>>()?; + + let parameters = match describe.parameters { + Some(Either::Left(parameters)) => Some(Either::Left( + parameters + .iter() + .enumerate() + .map(|(i, type_info)| { + AnyTypeInfo::try_from(type_info).map_err(|_| { + sqlx_core::Error::AnyDriverError( + format!( + "Any driver does not support type {type_info} of parameter {i}" + ) + .into(), + ) + }) + }) + .collect::, _>>()?, + )), + Some(Either::Right(count)) => Some(Either::Right(count)), + None => None, + }; + + Ok(Describe { + columns, + parameters, + nullable: describe.nullable, + }) + }) + } +} + +impl<'a> TryFrom<&'a PgTypeInfo> for AnyTypeInfo { + type Error = sqlx_core::Error; + + fn try_from(pg_type: &'a PgTypeInfo) -> Result { + Ok(AnyTypeInfo { + kind: match &pg_type.0 { + PgType::Bool => AnyTypeInfoKind::Bool, + PgType::Void => AnyTypeInfoKind::Null, + PgType::Int2 => AnyTypeInfoKind::SmallInt, + PgType::Int4 => AnyTypeInfoKind::Integer, + PgType::Int8 => AnyTypeInfoKind::BigInt, + PgType::Float4 => AnyTypeInfoKind::Real, + PgType::Float8 => AnyTypeInfoKind::Double, + PgType::Bytea => AnyTypeInfoKind::Blob, + PgType::Text | PgType::Varchar => AnyTypeInfoKind::Text, + PgType::DeclareWithName(UStr::Static("citext")) => AnyTypeInfoKind::Text, + _ => { + return Err(sqlx_core::Error::AnyDriverError( + format!("Any driver does not support the Postgres type {pg_type:?}").into(), + )) + } + }, + }) + } +} + +impl<'a> TryFrom<&'a PgColumn> for AnyColumn { + type Error = sqlx_core::Error; + + fn try_from(col: &'a PgColumn) -> Result { + let type_info = + AnyTypeInfo::try_from(&col.type_info).map_err(|e| sqlx_core::Error::ColumnDecode { + index: col.name.to_string(), + source: e.into(), + })?; + + Ok(AnyColumn { + ordinal: col.ordinal, + name: col.name.clone(), + type_info, + }) + } +} + +impl<'a> TryFrom<&'a PgRow> for AnyRow { + type Error = sqlx_core::Error; + + fn try_from(row: &'a PgRow) -> Result { + AnyRow::map_from(row, row.metadata.column_names.clone()) + } +} + +impl<'a> TryFrom<&'a AnyConnectOptions> for PgConnectOptions { + type Error = sqlx_core::Error; + + fn try_from(value: &'a AnyConnectOptions) -> Result { + let mut opts = PgConnectOptions::parse_from_url(&value.database_url)?; + opts.log_settings = value.log_settings.clone(); + Ok(opts) + } +} + +fn map_result(res: PgQueryResult) -> AnyQueryResult { + AnyQueryResult { + rows_affected: res.rows_affected(), + last_insert_id: None, + } +} diff --git a/patches/sqlx-postgres/src/arguments.rs b/patches/sqlx-postgres/src/arguments.rs new file mode 100644 index 000000000..bc7e861c5 --- /dev/null +++ b/patches/sqlx-postgres/src/arguments.rs @@ -0,0 +1,284 @@ +use std::fmt::{self, Write}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +use crate::encode::{Encode, IsNull}; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::types::Type; +use crate::{PgConnection, PgTypeInfo, Postgres}; + 
+use crate::type_info::PgArrayOf; +pub(crate) use sqlx_core::arguments::Arguments; +use sqlx_core::error::BoxDynError; + +// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ? +// TODO: Extend the patch system to support dynamic lengths +// Considerations: +// - The prefixed-len offset needs to be back-tracked and updated +// - message::Bind needs to take a &PgArguments and use a `write` method instead of +// referencing a buffer directly +// - The basic idea is that we write bytes for the buffer until we get somewhere +// that has a patch, we then apply the patch which should write to &mut Vec, +// backtrack and update the prefixed-len, then write until the next patch offset + +#[derive(Default)] +pub struct PgArgumentBuffer { + buffer: Vec, + + // Number of arguments + count: usize, + + // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types + // it can use `patch`. + // + // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be + // tweaked from the input type. However, that's the only use case we currently have. + patches: Vec, + + // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID + // It pushes a "hole" that must be patched later. + // + // The hole is a `usize` offset into the buffer with the type name that should be resolved + // This is done for Records and Arrays as the OID is needed well before we are in an async + // function and can just ask postgres. + // + type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }> +} + +enum HoleKind { + Type { name: UStr }, + Array(Arc), +} + +struct Patch { + buf_offset: usize, + arg_index: usize, + #[allow(clippy::type_complexity)] + callback: Box, +} + +/// Implementation of [`Arguments`] for PostgreSQL. +#[derive(Default)] +pub struct PgArguments { + // Types of each bind parameter + pub(crate) types: Vec, + + // Buffer of encoded bind parameters + pub(crate) buffer: PgArgumentBuffer, +} + +impl PgArguments { + pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'q, Postgres> + Type, + { + let type_info = value.produces().unwrap_or_else(T::type_info); + + let buffer_snapshot = self.buffer.snapshot(); + + // encode the value into our buffer + if let Err(error) = self.buffer.encode(value) { + // reset the value buffer to its previous value if encoding failed, + // so we don't leave a half-encoded value behind + self.buffer.reset_to_snapshot(buffer_snapshot); + return Err(error); + }; + + // remember the type information for this value + self.types.push(type_info); + // increment the number of arguments we are tracking + self.buffer.count += 1; + + Ok(()) + } + + // Apply patches + // This should only go out and ask postgres if we have not seen the type name yet + pub(crate) async fn apply_patches( + &mut self, + conn: &mut PgConnection, + parameters: &[PgTypeInfo], + ) -> Result<(), Error> { + let PgArgumentBuffer { + ref patches, + ref type_holes, + ref mut buffer, + .. 
+ } = self.buffer; + + for patch in patches { + let buf = &mut buffer[patch.buf_offset..]; + let ty = ¶meters[patch.arg_index]; + + (patch.callback)(buf, ty); + } + + for (offset, kind) in type_holes { + let oid = match kind { + HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?, + HoleKind::Array(array) => conn.fetch_array_type_id(array).await?, + }; + buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes()); + } + + Ok(()) + } +} + +impl<'q> Arguments<'q> for PgArguments { + type Database = Postgres; + + fn reserve(&mut self, additional: usize, size: usize) { + self.types.reserve(additional); + self.buffer.reserve(size); + } + + fn add(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'q, Self::Database> + Type, + { + self.add(value) + } + + fn format_placeholder(&self, writer: &mut W) -> fmt::Result { + write!(writer, "${}", self.buffer.count) + } + + #[inline(always)] + fn len(&self) -> usize { + self.buffer.count + } +} + +impl PgArgumentBuffer { + pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'q, Postgres>, + { + // Won't catch everything but is a good sanity check + value_size_int4_checked(value.size_hint())?; + + // reserve space to write the prefixed length of the value + let offset = self.len(); + + self.extend(&[0; 4]); + + // encode the value into our buffer + let len = if let IsNull::No = value.encode(self)? { + // Ensure that the value size does not overflow i32 + value_size_int4_checked(self.len() - offset - 4)? + } else { + // Write a -1 to indicate NULL + // NOTE: It is illegal for [encode] to write any data + debug_assert_eq!(self.len(), offset + 4); + -1_i32 + }; + + // write the len to the beginning of the value + // (offset + 4) cannot overflow because it would have failed at `self.extend()`. 
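+        // For example (sketch): a non-NULL INT4 value 7 occupies the wire as
+        // [0, 0, 0, 4, 0, 0, 0, 7] (big-endian length prefix, then payload),
+        // while a NULL of any type is just the length slot set to -1,
+        // i.e. [0xFF, 0xFF, 0xFF, 0xFF].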
+ self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); + + Ok(()) + } + + // Adds a callback to be invoked later when we know the parameter type + #[allow(dead_code)] + pub(crate) fn patch(&mut self, callback: F) + where + F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync, + { + let offset = self.len(); + let arg_index = self.count; + + self.patches.push(Patch { + buf_offset: offset, + arg_index, + callback: Box::new(callback), + }); + } + + // Extends the inner buffer by enough space to have an OID + // Remembers where the OID goes and type name for the OID + pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) { + let offset = self.len(); + + self.extend_from_slice(&0_u32.to_be_bytes()); + self.type_holes.push(( + offset, + HoleKind::Type { + name: type_name.clone(), + }, + )); + } + + pub(crate) fn patch_array_type(&mut self, array: Arc) { + let offset = self.len(); + + self.extend_from_slice(&0_u32.to_be_bytes()); + self.type_holes.push((offset, HoleKind::Array(array))); + } + + fn snapshot(&self) -> PgArgumentBufferSnapshot { + let Self { + buffer, + count, + patches, + type_holes, + } = self; + + PgArgumentBufferSnapshot { + buffer_length: buffer.len(), + count: *count, + patches_length: patches.len(), + type_holes_length: type_holes.len(), + } + } + + fn reset_to_snapshot( + &mut self, + PgArgumentBufferSnapshot { + buffer_length, + count, + patches_length, + type_holes_length, + }: PgArgumentBufferSnapshot, + ) { + self.buffer.truncate(buffer_length); + self.count = count; + self.patches.truncate(patches_length); + self.type_holes.truncate(type_holes_length); + } +} + +struct PgArgumentBufferSnapshot { + buffer_length: usize, + count: usize, + patches_length: usize, + type_holes_length: usize, +} + +impl Deref for PgArgumentBuffer { + type Target = Vec; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.buffer + } +} + +impl DerefMut for PgArgumentBuffer { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.buffer + } +} + +pub(crate) fn value_size_int4_checked(size: usize) -> Result { + i32::try_from(size).map_err(|_| { + format!( + "value size would overflow in the binary protocol encoding: {size} > {}", + i32::MAX + ) + }) +} diff --git a/patches/sqlx-postgres/src/column.rs b/patches/sqlx-postgres/src/column.rs new file mode 100644 index 000000000..cc2d259a4 --- /dev/null +++ b/patches/sqlx-postgres/src/column.rs @@ -0,0 +1,32 @@ +use crate::ext::ustr::UStr; +use crate::{PgTypeInfo, Postgres}; + +pub(crate) use sqlx_core::column::{Column, ColumnIndex}; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct PgColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: PgTypeInfo, + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) relation_id: Option, + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) relation_attribute_no: Option, +} + +impl Column for PgColumn { + type Database = Postgres; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &self.name + } + + fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } +} diff --git a/patches/sqlx-postgres/src/connection/describe.rs b/patches/sqlx-postgres/src/connection/describe.rs new file mode 100644 index 000000000..9a46a202d --- /dev/null +++ b/patches/sqlx-postgres/src/connection/describe.rs @@ -0,0 +1,662 @@ +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::io::StatementId; +use 
crate::message::{ParameterDescription, RowDescription}; +use crate::query_as::query_as; +use crate::query_scalar::query_scalar; +use crate::statement::PgStatementMetadata; +use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind}; +use crate::types::Json; +use crate::types::Oid; +use crate::HashMap; +use crate::{PgColumn, PgConnection, PgTypeInfo}; +use futures_core::future::BoxFuture; +use smallvec::SmallVec; +use sqlx_core::query_builder::QueryBuilder; +use std::sync::Arc; + +/// Describes the type of the `pg_type.typtype` column +/// +/// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum TypType { + Base, + Composite, + Domain, + Enum, + Pseudo, + Range, +} + +impl TryFrom for TypType { + type Error = (); + + fn try_from(t: i8) -> Result { + let t = u8::try_from(t).or(Err(()))?; + + let t = match t { + b'b' => Self::Base, + b'c' => Self::Composite, + b'd' => Self::Domain, + b'e' => Self::Enum, + b'p' => Self::Pseudo, + b'r' => Self::Range, + _ => return Err(()), + }; + Ok(t) + } +} + +/// Describes the type of the `pg_type.typcategory` column +/// +/// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum TypCategory { + Array, + Boolean, + Composite, + DateTime, + Enum, + Geometric, + Network, + Numeric, + Pseudo, + Range, + String, + Timespan, + User, + BitString, + Unknown, +} + +impl TryFrom for TypCategory { + type Error = (); + + fn try_from(c: i8) -> Result { + let c = u8::try_from(c).or(Err(()))?; + + let c = match c { + b'A' => Self::Array, + b'B' => Self::Boolean, + b'C' => Self::Composite, + b'D' => Self::DateTime, + b'E' => Self::Enum, + b'G' => Self::Geometric, + b'I' => Self::Network, + b'N' => Self::Numeric, + b'P' => Self::Pseudo, + b'R' => Self::Range, + b'S' => Self::String, + b'T' => Self::Timespan, + b'U' => Self::User, + b'V' => Self::BitString, + b'X' => Self::Unknown, + _ => return Err(()), + }; + Ok(c) + } +} + +impl PgConnection { + pub(super) async fn handle_row_description( + &mut self, + desc: Option, + should_fetch: bool, + ) -> Result<(Vec, HashMap), Error> { + let mut columns = Vec::new(); + let mut column_names = HashMap::new(); + + let desc = if let Some(desc) = desc { + desc + } else { + // no rows + return Ok((columns, column_names)); + }; + + columns.reserve(desc.fields.len()); + column_names.reserve(desc.fields.len()); + + for (index, field) in desc.fields.into_iter().enumerate() { + let name = UStr::from(field.name); + + let type_info = self + .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) + .await?; + + let column = PgColumn { + ordinal: index, + name: name.clone(), + type_info, + relation_id: field.relation_id, + relation_attribute_no: field.relation_attribute_no, + }; + + columns.push(column); + column_names.insert(name, index); + } + + Ok((columns, column_names)) + } + + pub(super) async fn handle_parameter_description( + &mut self, + desc: ParameterDescription, + ) -> Result, Error> { + let mut params = Vec::with_capacity(desc.types.len()); + + for ty in desc.types { + params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?); + } + + Ok(params) + } + + async fn maybe_fetch_type_info_by_oid( + &mut self, + oid: Oid, + should_fetch: bool, + ) -> Result { + // first we check if this is a built-in type + // in the average application, the vast majority of checks should flow through this + if let Some(info) = PgTypeInfo::try_from_oid(oid) { + return Ok(info); + } + + // next we check a local cache for user-defined type names <-> object id + if let Some(info) = self.cache_type_info.get(&oid) { + return 
Ok(info.clone()); + } + + // fallback to asking the database directly for a type name + if should_fetch { + let info = self.fetch_type_by_oid(oid).await?; + + // cache the type name <-> oid relationship in a paired hashmap + // so we don't come down this road again + self.cache_type_info.insert(oid, info.clone()); + self.cache_type_oid + .insert(info.0.name().to_string().into(), oid); + + Ok(info) + } else { + // we are not in a place that *can* run a query + // this generally means we are in the middle of another query + // this _should_ only happen for complex types sent through the TEXT protocol + // we're open to ideas to correct this.. but it'd probably be more efficient to figure + // out a way to "prime" the type cache for connections rather than make this + // fallback work correctly for complex user-defined types for the TEXT protocol + Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) + } + } + + fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { + Box::pin(async move { + let (name, typ_type, category, relation_id, element, base_type): ( + String, + i8, + i8, + Oid, + Oid, + Oid, + ) = query_as( + // Converting the OID to `regtype` and then `text` will give us the name that + // the type will need to be found at by search_path. + "SELECT oid::regtype::text, \ + typtype, \ + typcategory, \ + typrelid, \ + typelem, \ + typbasetype \ + FROM pg_catalog.pg_type \ + WHERE oid = $1", + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; + + let typ_type = TypType::try_from(typ_type); + let category = TypCategory::try_from(category); + + match (typ_type, category) { + (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, + + (Ok(TypType::Base), Ok(TypCategory::Array)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Array( + self.maybe_fetch_type_info_by_oid(element, true).await?, + ), + name: name.into(), + oid, + })))) + } + + (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Pseudo, + name: name.into(), + oid, + })))) + } + + (Ok(TypType::Range), Ok(TypCategory::Range)) => { + self.fetch_range_by_oid(oid, name).await + } + + (Ok(TypType::Enum), Ok(TypCategory::Enum)) => { + self.fetch_enum_by_oid(oid, name).await + } + + (Ok(TypType::Composite), Ok(TypCategory::Composite)) => { + self.fetch_composite_by_oid(oid, relation_id, name).await + } + + _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Simple, + name: name.into(), + oid, + })))), + } + }) + } + + async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result { + let variants: Vec = query_scalar( + r#" +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY enumsortorder + "#, + ) + .bind(oid) + .fetch_all(self) + .await?; + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Enum(Arc::from(variants)), + })))) + } + + fn fetch_composite_by_oid( + &mut self, + oid: Oid, + relation_id: Oid, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let raw_fields: Vec<(String, Oid)> = query_as( + r#" +SELECT attname, atttypid +FROM pg_catalog.pg_attribute +WHERE attrelid = $1 +AND NOT attisdropped +AND attnum > 0 +ORDER BY attnum + "#, + ) + .bind(relation_id) + .fetch_all(&mut *self) + .await?; + + let mut fields = Vec::new(); + + for (field_name, field_oid) in raw_fields.into_iter() { + let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?; + + 
fields.push((field_name, field_type)); + } + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Composite(Arc::from(fields)), + })))) + }) + } + + fn fetch_domain_by_oid( + &mut self, + oid: Oid, + base_type: Oid, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Domain(base_type), + })))) + }) + } + + fn fetch_range_by_oid( + &mut self, + oid: Oid, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let element_oid: Oid = query_scalar( + r#" +SELECT rngsubtype +FROM pg_catalog.pg_range +WHERE rngtypid = $1 + "#, + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; + + let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Range(element), + name: name.into(), + oid, + })))) + }) + } + + pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result { + if let Some(oid) = ty.try_oid() { + return Ok(oid); + } + + match ty { + PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await, + PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await, + // `.try_oid()` should return `Some()` or it should be covered here + _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"), + } + } + + pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result { + if let Some(oid) = self.cache_type_oid.get(name) { + return Ok(*oid); + } + + // language=SQL + let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid") + .bind(name) + .fetch_optional(&mut *self) + .await? + .ok_or_else(|| Error::TypeNotFound { + type_name: name.into(), + })?; + + self.cache_type_oid.insert(name.to_string().into(), oid); + Ok(oid) + } + + pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result { + if let Some(oid) = self + .cache_type_oid + .get(&array.elem_name) + .and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid)) + { + return Ok(*oid); + } + + // language=SQL + let (elem_oid, array_oid): (Oid, Oid) = + query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid") + .bind(&*array.elem_name) + .fetch_optional(&mut *self) + .await? + .ok_or_else(|| Error::TypeNotFound { + type_name: array.name.to_string(), + })?; + + // Avoids copying `elem_name` until necessary + self.cache_type_oid + .entry_ref(&array.elem_name) + .insert(elem_oid); + self.cache_elem_type_to_array.insert(elem_oid, array_oid); + + Ok(array_oid) + } + + pub(crate) async fn get_nullable_for_columns( + &mut self, + stmt_id: StatementId, + meta: &PgStatementMetadata, + ) -> Result>, Error> { + if meta.columns.is_empty() { + return Ok(vec![]); + } + + if meta.columns.len() * 3 > 65535 { + tracing::debug!( + ?stmt_id, + num_columns = meta.columns.len(), + "number of columns in query is too large to pull nullability for" + ); + } + + // Query for NOT NULL constraints for each column in the query. + // + // This will include columns that don't have a `relation_id` (are not from a table); + // assuming those are a minority of columns, it's less code to _not_ work around it + // and just let Postgres return `NULL`. 
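+        //
+        // The query built below has roughly this shape (sketch; placeholder
+        // numbering depends on the column count):
+        //
+        //     SELECT NOT pg_attribute.attnotnull FROM (
+        //         VALUES ($1::int4, $2::int4, $3::int2), ($4::int4, ...)
+        //     ) as col(idx, table_id, col_idx)
+        //     LEFT JOIN pg_catalog.pg_attribute
+        //         ON table_id IS NOT NULL AND attrelid = table_id AND attnum = col_idx
+        //     ORDER BY col.idx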
+ let mut nullable_query = QueryBuilder::new("SELECT NOT pg_attribute.attnotnull FROM ( "); + + nullable_query.push_values(meta.columns.iter().zip(0i32..), |mut tuple, (column, i)| { + // ({i}::int4, {column.relation_id}::int4, {column.relation_attribute_no}::int2) + tuple.push_bind(i).push_unseparated("::int4"); + tuple + .push_bind(column.relation_id) + .push_unseparated("::int4"); + tuple + .push_bind(column.relation_attribute_no) + .push_unseparated("::int2"); + }); + + nullable_query.push( + ") as col(idx, table_id, col_idx) \ + LEFT JOIN pg_catalog.pg_attribute \ + ON table_id IS NOT NULL \ + AND attrelid = table_id \ + AND attnum = col_idx \ + ORDER BY col.idx", + ); + + let mut nullables: Vec> = nullable_query + .build_query_scalar() + .fetch_all(&mut *self) + .await + .map_err(|e| { + err_protocol!( + "error from nullables query: {e}; query: {:?}", + nullable_query.sql() + ) + })?; + + // If the server is CockroachDB or Materialize, skip this step (#1248). + if !self.stream.parameter_statuses.contains_key("crdb_version") + && !self.stream.parameter_statuses.contains_key("mz_version") + { + // patch up our null inference with data from EXPLAIN + let nullable_patch = self + .nullables_from_explain(stmt_id, meta.parameters.len()) + .await?; + + for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { + *nullable = patch.or(*nullable); + } + } + + Ok(nullables) + } + + /// Infer nullability for columns of this statement using EXPLAIN VERBOSE. + /// + /// This currently only marks columns that are on the inner half of an outer join + /// and returns `None` for all others. + async fn nullables_from_explain( + &mut self, + stmt_id: StatementId, + params_len: usize, + ) -> Result>, Error> { + let stmt_id_display = stmt_id + .display() + .ok_or_else(|| err_protocol!("cannot EXPLAIN unnamed statement: {stmt_id:?}"))?; + + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id_display}"); + let mut comma = false; + + if params_len > 0 { + explain += "("; + + // fill the arguments list with NULL, which should theoretically be valid + for _ in 0..params_len { + if comma { + explain += ", "; + } + + explain += "NULL"; + comma = true; + } + + explain += ")"; + } + + let (Json(explains),): (Json>,) = + query_as(&explain).fetch_one(self).await?; + + let mut nullables = Vec::new(); + + if let Some(Explain::Plan { + plan: + plan @ Plan { + output: Some(ref outputs), + .. + }, + }) = explains.first() + { + nullables.resize(outputs.len(), None); + visit_plan(plan, outputs, &mut nullables); + } + + Ok(nullables) + } +} + +fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec>) { + if let Some(plan_outputs) = &plan.output { + // all outputs of a Full Join must be marked nullable + // otherwise, all outputs of the inner half of an outer join must be marked nullable + if plan.join_type.as_deref() == Some("Full") + || plan.parent_relation.as_deref() == Some("Inner") + { + for output in plan_outputs { + if let Some(i) = outputs.iter().position(|o| o == output) { + // N.B. 
this may produce false positives but those don't cause runtime errors + nullables[i] = Some(true); + } + } + } + } + + if let Some(plans) = &plan.plans { + if let Some("Left") | Some("Right") = plan.join_type.as_deref() { + for plan in plans { + visit_plan(plan, outputs, nullables); + } + } + } +} + +#[derive(serde::Deserialize, Debug)] +#[serde(untagged)] +enum Explain { + // NOTE: the returned JSON may not contain a `plan` field, for example, with `CALL` statements: + // https://github.com/launchbadge/sqlx/issues/1449 + // + // In this case, we should just fall back to assuming all is nullable. + // + // It may also contain additional fields we don't care about, which should not break parsing: + // https://github.com/launchbadge/sqlx/issues/2587 + // https://github.com/launchbadge/sqlx/issues/2622 + Plan { + #[serde(rename = "Plan")] + plan: Plan, + }, + + // This ensures that parsing never technically fails. + // + // We don't want to specifically expect `"Utility Statement"` because there might be other cases + // and we don't care unless it contains a query plan anyway. + Other(serde::de::IgnoredAny), +} + +#[derive(serde::Deserialize, Debug)] +struct Plan { + #[serde(rename = "Join Type")] + join_type: Option, + #[serde(rename = "Parent Relationship")] + parent_relation: Option, + #[serde(rename = "Output")] + output: Option>, + #[serde(rename = "Plans")] + plans: Option>, +} + +#[test] +fn explain_parsing() { + let normal_plan = r#"[ + { + "Plan": { + "Node Type": "Result", + "Parallel Aware": false, + "Async Capable": false, + "Startup Cost": 0.00, + "Total Cost": 0.01, + "Plan Rows": 1, + "Plan Width": 4, + "Output": ["1"] + } + } +]"#; + + // https://github.com/launchbadge/sqlx/issues/2622 + let extra_field = r#"[ + { + "Plan": { + "Node Type": "Result", + "Parallel Aware": false, + "Async Capable": false, + "Startup Cost": 0.00, + "Total Cost": 0.01, + "Plan Rows": 1, + "Plan Width": 4, + "Output": ["1"] + }, + "Query Identifier": 1147616880456321454 + } +]"#; + + // https://github.com/launchbadge/sqlx/issues/1449 + let utility_statement = r#"["Utility Statement"]"#; + + let normal_plan_parsed = serde_json::from_str::<[Explain; 1]>(normal_plan).unwrap(); + let extra_field_parsed = serde_json::from_str::<[Explain; 1]>(extra_field).unwrap(); + let utility_statement_parsed = serde_json::from_str::<[Explain; 1]>(utility_statement).unwrap(); + + assert!( + matches!(normal_plan_parsed, [Explain::Plan { plan: Plan { .. } }]), + "unexpected parse from {normal_plan:?}: {normal_plan_parsed:?}" + ); + + assert!( + matches!(extra_field_parsed, [Explain::Plan { plan: Plan { .. 
} }]), + "unexpected parse from {extra_field:?}: {extra_field_parsed:?}" + ); + + assert!( + matches!(utility_statement_parsed, [Explain::Other(_)]), + "unexpected parse from {utility_statement:?}: {utility_statement_parsed:?}" + ) +} diff --git a/patches/sqlx-postgres/src/connection/establish.rs b/patches/sqlx-postgres/src/connection/establish.rs new file mode 100644 index 000000000..a730f5c16 --- /dev/null +++ b/patches/sqlx-postgres/src/connection/establish.rs @@ -0,0 +1,151 @@ +use crate::HashMap; + +use crate::common::StatementCache; +use crate::connection::{sasl, stream::PgStream}; +use crate::error::Error; +use crate::io::StatementId; +use crate::message::{ + Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, +}; +use crate::{PgConnectOptions, PgConnection}; + +// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 +// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 + +impl PgConnection { + pub(crate) async fn establish(options: &PgConnectOptions) -> Result { + // Upgrade to TLS if we were asked to and the server supports it + let mut stream = PgStream::connect(options).await?; + + // To begin a session, a frontend opens a connection to the server + // and sends a startup message. + + let mut params = vec![ + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + ("DateStyle", "ISO, MDY"), + // Sets the client-side encoding (character set). + // + ("client_encoding", "UTF8"), + // Sets the time zone for displaying and interpreting time stamps. + ("TimeZone", "UTC"), + ]; + + if let Some(ref extra_float_digits) = options.extra_float_digits { + params.push(("extra_float_digits", extra_float_digits)); + } + + if let Some(ref application_name) = options.application_name { + params.push(("application_name", application_name)); + } + + if let Some(ref options) = options.options { + params.push(("options", options)); + } + + stream.write(Startup { + username: Some(&options.username), + database: options.database.as_deref(), + params: ¶ms, + })?; + + stream.flush().await?; + + // The server then uses this information and the contents of + // its configuration files (such as pg_hba.conf) to determine whether the connection is + // provisionally acceptable, and what additional + // authentication is required (if any). + + let mut process_id = 0; + let mut secret_key = 0; + let transaction_status; + + loop { + let message = stream.recv().await?; + match message.format { + BackendMessageFormat::Authentication => match message.decode()? { + Authentication::Ok => { + // the authentication exchange is successfully completed + // do nothing; no more information is required to continue + } + + Authentication::CleartextPassword => { + // The frontend must now send a [PasswordMessage] containing the + // password in clear-text form. + + stream + .send(Password::Cleartext( + options.password.as_deref().unwrap_or_default(), + )) + .await?; + } + + Authentication::Md5Password(body) => { + // The frontend must now send a [PasswordMessage] containing the + // password (with user name) encrypted via MD5, then encrypted again + // using the 4-byte random salt specified in the + // [AuthenticationMD5Password] message. 
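+                        //
+                        // Concretely, the client answers with the string
+                        // "md5" || hex(md5(hex(md5(password || username)) || salt)).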
+ + stream + .send(Password::Md5 { + username: &options.username, + password: options.password.as_deref().unwrap_or_default(), + salt: body.salt, + }) + .await?; + } + + Authentication::Sasl(body) => { + sasl::authenticate(&mut stream, options, body).await?; + } + + method => { + return Err(err_protocol!( + "unsupported authentication method: {:?}", + method + )); + } + }, + + BackendMessageFormat::BackendKeyData => { + // provides secret-key data that the frontend must save if it wants to be + // able to issue cancel requests later + + let data: BackendKeyData = message.decode()?; + + process_id = data.process_id; + secret_key = data.secret_key; + } + + BackendMessageFormat::ReadyForQuery => { + // start-up is completed. The frontend can now issue commands + transaction_status = message.decode::()?.transaction_status; + + break; + } + + _ => { + return Err(err_protocol!( + "establish: unexpected message: {:?}", + message.format + )) + } + } + } + + Ok(PgConnection { + stream, + process_id, + secret_key, + transaction_status, + transaction_depth: 0, + pending_ready_for_query_count: 0, + next_statement_id: StatementId::NAMED_START, + cache_statement: StatementCache::new(options.statement_cache_capacity), + cache_type_oid: HashMap::new(), + cache_type_info: HashMap::new(), + cache_elem_type_to_array: HashMap::new(), + log_settings: options.log_settings.clone(), + }) + } +} diff --git a/patches/sqlx-postgres/src/connection/executor.rs b/patches/sqlx-postgres/src/connection/executor.rs new file mode 100644 index 000000000..d2f6bcddf --- /dev/null +++ b/patches/sqlx-postgres/src/connection/executor.rs @@ -0,0 +1,472 @@ +use crate::describe::Describe; +use crate::error::Error; +use crate::executor::{Execute, Executor}; +use crate::io::{PortalId, StatementId}; +use crate::logger::QueryLogger; +use crate::message::{ + self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse, + ParseComplete, Query, RowDescription, +}; +use crate::statement::PgStatementMetadata; +use crate::{ + statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, + PgValueFormat, Postgres, +}; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_core::Stream; +use futures_util::{pin_mut, TryStreamExt}; +use sqlx_core::arguments::Arguments; +use sqlx_core::Either; +use std::{borrow::Cow, sync::Arc}; + +async fn prepare( + conn: &mut PgConnection, + sql: &str, + parameters: &[PgTypeInfo], + metadata: Option>, +) -> Result<(StatementId, Arc), Error> { + let id = conn.next_statement_id; + conn.next_statement_id = id.next(); + + // build a list of type OIDs to send to the database in the PARSE command + // we have not yet started the query sequence, so we are *safe* to cleanly make + // additional queries here to get any missing OIDs + + let mut param_types = Vec::with_capacity(parameters.len()); + + for ty in parameters { + param_types.push(conn.resolve_type_id(&ty.0).await?); + } + + // flush and wait until we are re-ready + conn.wait_until_ready().await?; + + // next we send the PARSE command to the server + conn.stream.write_msg(Parse { + param_types: ¶m_types, + query: sql, + statement: id, + })?; + + if metadata.is_none() { + // get the statement columns and parameters + conn.stream.write_msg(message::Describe::Statement(id))?; + } + + // we ask for the server to immediately send us the result of the PARSE command + conn.write_sync(); + conn.stream.flush().await?; + + // indicates that the SQL query string is now successfully 
parsed and has semantic validity + conn.stream.recv_expect::().await?; + + let metadata = if let Some(metadata) = metadata { + // each SYNC produces one READY FOR QUERY + conn.recv_ready_for_query().await?; + + // we already have metadata + metadata + } else { + let parameters = recv_desc_params(conn).await?; + + let rows = recv_desc_rows(conn).await?; + + // each SYNC produces one READY FOR QUERY + conn.recv_ready_for_query().await?; + + let parameters = conn.handle_parameter_description(parameters).await?; + + let (columns, column_names) = conn.handle_row_description(rows, true).await?; + + // ensure that if we did fetch custom data, we wait until we are fully ready before + // continuing + conn.wait_until_ready().await?; + + Arc::new(PgStatementMetadata { + parameters, + columns, + column_names: Arc::new(column_names), + }) + }; + + Ok((id, metadata)) +} + +async fn recv_desc_params(conn: &mut PgConnection) -> Result { + conn.stream.recv_expect().await +} + +async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { + let rows: Option = match conn.stream.recv().await? { + // describes the rows that will be returned when the statement is eventually executed + message if message.format == BackendMessageFormat::RowDescription => { + Some(message.decode()?) + } + + // no data would be returned if this statement was executed + message if message.format == BackendMessageFormat::NoData => None, + + message => { + return Err(err_protocol!( + "expecting RowDescription or NoData but received {:?}", + message.format + )); + } + }; + + Ok(rows) +} + +impl PgConnection { + // wait for CloseComplete to indicate a statement was closed + pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> { + // we need to wait for the [CloseComplete] to be returned from the server + while count > 0 { + match self.stream.recv().await? 
{ + message if message.format == BackendMessageFormat::PortalSuspended => { + // there was an open portal + // this can happen if the last time a statement was used it was not fully executed + } + + message if message.format == BackendMessageFormat::CloseComplete => { + // successfully closed the statement (and freed up the server resources) + count -= 1; + } + + message => { + return Err(err_protocol!( + "expecting PortalSuspended or CloseComplete but received {:?}", + message.format + )); + } + } + } + + Ok(()) + } + + #[inline(always)] + pub(crate) fn write_sync(&mut self) { + self.stream + .write_msg(message::Sync) + .expect("BUG: Sync should not be too big for protocol"); + + // all SYNC messages will return a ReadyForQuery + self.pending_ready_for_query_count += 1; + } + + async fn get_or_prepare<'a>( + &mut self, + sql: &str, + parameters: &[PgTypeInfo], + // should we store the result of this prepare to the cache + store_to_cache: bool, + // optional metadata that was provided by the user, this means they are reusing + // a statement object + metadata: Option>, + ) -> Result<(StatementId, Arc), Error> { + if let Some(statement) = self.cache_statement.get_mut(sql) { + return Ok((*statement).clone()); + } + + let statement = prepare(self, sql, parameters, metadata).await?; + + if store_to_cache && self.cache_statement.is_enabled() { + if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) { + self.stream.write_msg(Close::Statement(id))?; + self.write_sync(); + + self.stream.flush().await?; + + self.wait_for_close_complete(1).await?; + self.recv_ready_for_query().await?; + } + } + + Ok(statement) + } + + pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( + &'c mut self, + query: &'q str, + arguments: Option, + limit: u8, + persistent: bool, + metadata_opt: Option>, + ) -> Result, Error>> + 'e, Error> { + let mut logger = QueryLogger::new(query, self.log_settings.clone()); + + // before we continue, wait until we are "ready" to accept more queries + self.wait_until_ready().await?; + + let mut metadata: Arc; + + let format = if let Some(mut arguments) = arguments { + // Check this before we write anything to the stream. + let num_params = i16::try_from(arguments.len()).map_err(|_| { + err_protocol!( + "PgConnection::run(): too many arguments for query: {}", + arguments.len() + ) + })?; + + // prepare the statement if this our first time executing it + // always return the statement ID here + let (statement, metadata_) = self + .get_or_prepare(query, &arguments.types, persistent, metadata_opt) + .await?; + + metadata = metadata_; + + // patch holes created during encoding + arguments.apply_patches(self, &metadata.parameters).await?; + + // consume messages till `ReadyForQuery` before bind and execute + self.wait_until_ready().await?; + + // bind to attach the arguments to the statement and create a portal + self.stream.write_msg(Bind { + portal: PortalId::UNNAMED, + statement, + formats: &[PgValueFormat::Binary], + num_params, + params: &arguments.buffer, + result_formats: &[PgValueFormat::Binary], + })?; + + // executes the portal up to the passed limit + // the protocol-level limit acts nearly identically to the `LIMIT` in SQL + self.stream.write_msg(message::Execute { + portal: PortalId::UNNAMED, + limit: limit.into(), + })?; + // From https://www.postgresql.org/docs/current/protocol-flow.html: + // + // "An unnamed portal is destroyed at the end of the transaction, or as + // soon as the next Bind statement specifying the unnamed portal as + // destination is issued. 
(Note that a simple Query message also + // destroys the unnamed portal." + + // we ask the database server to close the unnamed portal and free the associated resources + // earlier - after the execution of the current query. + self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?; + + // finally, [Sync] asks postgres to process the messages that we sent and respond with + // a [ReadyForQuery] message when it's completely done. Theoretically, we could send + // dozens of queries before a [Sync] and postgres can handle that. Execution on the server + // is still serial but it would reduce round-trips. Some kind of builder pattern that is + // termed batching might suit this. + self.write_sync(); + + // prepared statements are binary + PgValueFormat::Binary + } else { + // Query will trigger a ReadyForQuery + self.stream.write_msg(Query(query))?; + self.pending_ready_for_query_count += 1; + + // metadata starts out as "nothing" + metadata = Arc::new(PgStatementMetadata::default()); + + // and unprepared statements are text + PgValueFormat::Text + }; + + self.stream.flush().await?; + + Ok(try_stream! { + loop { + let message = self.stream.recv().await?; + + match message.format { + BackendMessageFormat::BindComplete + | BackendMessageFormat::ParseComplete + | BackendMessageFormat::ParameterDescription + | BackendMessageFormat::NoData + // unnamed portal has been closed + | BackendMessageFormat::CloseComplete + => { + // harmless messages to ignore + } + + // "Execute phase is always terminated by the appearance of + // exactly one of these messages: CommandComplete, + // EmptyQueryResponse (if the portal was created from an + // empty query string), ErrorResponse, or PortalSuspended" + BackendMessageFormat::CommandComplete => { + // a SQL command completed normally + let cc: CommandComplete = message.decode()?; + + let rows_affected = cc.rows_affected(); + logger.increase_rows_affected(rows_affected); + r#yield!(Either::Left(PgQueryResult { + rows_affected, + })); + } + + BackendMessageFormat::EmptyQueryResponse => { + // empty query string passed to an unprepared execute + } + + // Message::ErrorResponse is handled in self.stream.recv() + + // incomplete query execution has finished + BackendMessageFormat::PortalSuspended => {} + + BackendMessageFormat::RowDescription => { + // indicates that a *new* set of rows are about to be returned + let (columns, column_names) = self + .handle_row_description(Some(message.decode()?), false) + .await?; + + metadata = Arc::new(PgStatementMetadata { + column_names: Arc::new(column_names), + columns, + parameters: Vec::default(), + }); + } + + BackendMessageFormat::DataRow => { + logger.increment_rows_returned(); + + // one of the set of rows returned by a SELECT, FETCH, etc query + let data: DataRow = message.decode()?; + let row = PgRow { + data, + format, + metadata: Arc::clone(&metadata), + }; + + r#yield!(Either::Right(row)); + } + + BackendMessageFormat::ReadyForQuery => { + // processing of the query string is complete + self.handle_ready_for_query(message)?; + break; + } + + _ => { + return Err(err_protocol!( + "execute: unexpected message: {:?}", + message.format + )); + } + } + } + + Ok(()) + }) + } +} + +impl<'c> Executor<'c> for &'c mut PgConnection { + type Database = Postgres; + + fn fetch_many<'e, 'q, E>( + self, + mut query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let sql = query.sql(); + // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 
+ #[allow(clippy::map_clone)] + let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); + let arguments = query.take_arguments().map_err(Error::Encode); + let persistent = query.persistent(); + + Box::pin(try_stream! { + let arguments = arguments?; + let s = self.run(sql, arguments, 0, persistent, metadata).await?; + pin_mut!(s); + + while let Some(v) = s.try_next().await? { + r#yield!(v); + } + + Ok(()) + }) + } + + fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let sql = query.sql(); + // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 + #[allow(clippy::map_clone)] + let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); + let arguments = query.take_arguments().map_err(Error::Encode); + let persistent = query.persistent(); + + Box::pin(async move { + let arguments = arguments?; + let s = self.run(sql, arguments, 1, persistent, metadata).await?; + pin_mut!(s); + + // With deferred constraints we need to check all responses as we + // could get a OK response (with uncommitted data), only to get an + // error response after (when the deferred constraint is actually + // checked). + let mut ret = None; + while let Some(result) = s.try_next().await? { + match result { + Either::Right(r) if ret.is_none() => ret = Some(r), + _ => {} + } + } + Ok(ret) + }) + } + + fn prepare_with<'e, 'q: 'e>( + self, + sql: &'q str, + parameters: &'e [PgTypeInfo], + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { + self.wait_until_ready().await?; + + let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; + + Ok(PgStatement { + sql: Cow::Borrowed(sql), + metadata, + }) + }) + } + + fn describe<'e, 'q: 'e>( + self, + sql: &'q str, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { + self.wait_until_ready().await?; + + let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; + + let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; + + Ok(Describe { + columns: metadata.columns.clone(), + nullable, + parameters: Some(Either::Left(metadata.parameters.clone())), + }) + }) + } +} diff --git a/patches/sqlx-postgres/src/connection/mod.rs b/patches/sqlx-postgres/src/connection/mod.rs new file mode 100644 index 000000000..5a6a597ea --- /dev/null +++ b/patches/sqlx-postgres/src/connection/mod.rs @@ -0,0 +1,232 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +use crate::HashMap; +use futures_core::future::BoxFuture; +use futures_util::FutureExt; + +use crate::common::StatementCache; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::io::StatementId; +use crate::message::{ + BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate, + TransactionStatus, +}; +use crate::statement::PgStatementMetadata; +use crate::transaction::Transaction; +use crate::types::Oid; +use crate::{PgConnectOptions, PgTypeInfo, Postgres}; + +pub(crate) use sqlx_core::connection::*; + +pub use self::stream::PgStream; + +pub(crate) mod describe; +mod establish; +mod executor; +mod sasl; +mod stream; +mod tls; + +/// A connection to a PostgreSQL database. 
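+///
+/// A minimal usage sketch (URL and credentials are illustrative), going through
+/// the generic [`Connection`] API:
+///
+/// ```no_run
+/// use sqlx::{Connection, PgConnection};
+///
+/// # async fn example() -> Result<(), sqlx::Error> {
+/// let mut conn = PgConnection::connect("postgres://postgres:password@localhost/test").await?;
+/// conn.ping().await?;
+/// conn.close().await?;
+/// # Ok(())
+/// # }
+/// ```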
+pub struct PgConnection { + // underlying TCP or UDS stream, + // wrapped in a potentially TLS stream, + // wrapped in a buffered stream + pub(crate) stream: PgStream, + + // process id of this backend + // used to send cancel requests + #[allow(dead_code)] + process_id: u32, + + // secret key of this backend + // used to send cancel requests + #[allow(dead_code)] + secret_key: u32, + + // sequence of statement IDs for use in preparing statements + // in PostgreSQL, the statement is prepared to a user-supplied identifier + next_statement_id: StatementId, + + // cache statement by query string to the id and columns + cache_statement: StatementCache<(StatementId, Arc)>, + + // cache user-defined types by id <-> info + cache_type_info: HashMap, + cache_type_oid: HashMap, + cache_elem_type_to_array: HashMap, + + // number of ReadyForQuery messages that we are currently expecting + pub(crate) pending_ready_for_query_count: usize, + + // current transaction status + transaction_status: TransactionStatus, + pub(crate) transaction_depth: usize, + + log_settings: LogSettings, +} + +impl PgConnection { + /// the version number of the server in `libpq` format + pub fn server_version_num(&self) -> Option { + self.stream.server_version_num + } + + // will return when the connection is ready for another query + pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { + if !self.stream.write_buffer_mut().is_empty() { + self.stream.flush().await?; + } + + while self.pending_ready_for_query_count > 0 { + let message = self.stream.recv().await?; + + if let BackendMessageFormat::ReadyForQuery = message.format { + self.handle_ready_for_query(message)?; + } + } + + Ok(()) + } + + async fn recv_ready_for_query(&mut self) -> Result<(), Error> { + let r: ReadyForQuery = self.stream.recv_expect().await?; + + self.pending_ready_for_query_count -= 1; + self.transaction_status = r.transaction_status; + + Ok(()) + } + + #[inline(always)] + fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> { + self.pending_ready_for_query_count = self + .pending_ready_for_query_count + .checked_sub(1) + .ok_or_else(|| err_protocol!("received more ReadyForQuery messages than expected"))?; + + self.transaction_status = message.decode::()?.transaction_status; + + Ok(()) + } + + /// Queue a simple query (not prepared) to execute the next time this connection is used. + /// + /// Used for rolling back transactions and releasing advisory locks. + #[inline(always)] + pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> { + self.stream.write_msg(Query(query))?; + self.pending_ready_for_query_count += 1; + + Ok(()) + } +} + +impl Debug for PgConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PgConnection").finish() + } +} + +impl Connection for PgConnection { + type Database = Postgres; + + type Options = PgConnectOptions; + + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + // The normal, graceful termination procedure is that the frontend sends a Terminate + // message and immediately closes the connection. + + // On receipt of this message, the backend closes the + // connection and terminates. 
+ + Box::pin(async move { + self.stream.send(Terminate).await?; + self.stream.shutdown().await?; + + Ok(()) + }) + } + + fn close_hard(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + self.stream.shutdown().await?; + + Ok(()) + }) + } + + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + // Users were complaining about this showing up in query statistics on the server. + // By sending a comment we avoid an error if the connection was in the middle of a rowset + // self.execute("/* SQLx ping */").map_ok(|_| ()).boxed() + + Box::pin(async move { + // The simplest call-and-response that's possible. + self.write_sync(); + self.wait_until_ready().await + }) + } + + fn begin(&mut self) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self) + } + + fn cached_statements_size(&self) -> usize { + self.cache_statement.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + self.cache_type_oid.clear(); + + let mut cleared = 0_usize; + + self.wait_until_ready().await?; + + while let Some((id, _)) = self.cache_statement.remove_lru() { + self.stream.write_msg(Close::Statement(id))?; + cleared += 1; + } + + if cleared > 0 { + self.write_sync(); + self.stream.flush().await?; + + self.wait_for_close_complete(cleared).await?; + self.recv_ready_for_query().await?; + } + + Ok(()) + }) + } + + fn shrink_buffers(&mut self) { + self.stream.shrink_buffers(); + } + + #[doc(hidden)] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { + self.wait_until_ready().boxed() + } + + #[doc(hidden)] + fn should_flush(&self) -> bool { + !self.stream.write_buffer().is_empty() + } +} + +// Implement `AsMut` so that `PgConnection` can be wrapped in +// a `PgAdvisoryLockGuard`. +// +// See: https://github.com/launchbadge/sqlx/issues/2520 +impl AsMut for PgConnection { + fn as_mut(&mut self) -> &mut PgConnection { + self + } +} diff --git a/patches/sqlx-postgres/src/connection/sasl.rs b/patches/sqlx-postgres/src/connection/sasl.rs new file mode 100644 index 000000000..729cc1fcc --- /dev/null +++ b/patches/sqlx-postgres/src/connection/sasl.rs @@ -0,0 +1,225 @@ +use crate::connection::stream::PgStream; +use crate::error::Error; +use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; +use crate::PgConnectOptions; +use hmac::{Hmac, Mac}; +use rand::Rng; +use sha2::{Digest, Sha256}; +use stringprep::saslprep; + +use base64::prelude::{Engine as _, BASE64_STANDARD}; + +const GS2_HEADER: &str = "n,,"; +const CHANNEL_ATTR: &str = "c"; +const USERNAME_ATTR: &str = "n"; +const CLIENT_PROOF_ATTR: &str = "p"; +const NONCE_ATTR: &str = "r"; + +pub(crate) async fn authenticate( + stream: &mut PgStream, + options: &PgConnectOptions, + data: AuthenticationSasl, +) -> Result<(), Error> { + let mut has_sasl = false; + let mut has_sasl_plus = false; + let mut unknown = Vec::new(); + + for mechanism in data.mechanisms() { + match mechanism { + "SCRAM-SHA-256" => { + has_sasl = true; + } + + "SCRAM-SHA-256-PLUS" => { + has_sasl_plus = true; + } + + _ => { + unknown.push(mechanism.to_owned()); + } + } + } + + if !has_sasl_plus && !has_sasl { + return Err(err_protocol!( + "unsupported SASL authentication mechanisms: {}", + unknown.join(", ") + )); + } + + // channel-binding = "c=" base64 + let mut channel_binding = format!("{CHANNEL_ATTR}="); + BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding); + + // "n=" saslname ;; Usernames are prepared using SASLprep. 
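+    // (This builds the "n=<user>" attribute of the client-first-message; for
+    // orientation, the whole SCRAM-SHA-256 exchange per RFC 5802 is:
+    //   C: client-first-message = gs2-header "n=<user>,r=<client-nonce>"
+    //   S: server-first-message = "r=<client+server nonce>,s=<salt>,i=<iterations>"
+    //   C: client-final-message = "c=<b64(gs2-header)>,r=<nonce>,p=<client-proof>"
+    //   S: server-final-message = "v=<server-signature>" .)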
+ let username = format!("{}={}", USERNAME_ATTR, options.username); + let username = match saslprep(&username) { + Ok(v) => v, + // TODO(danielakhterov): Remove panic when we have proper support for configuration errors + Err(_) => panic!("Failed to saslprep username"), + }; + + // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server. + let nonce = gen_nonce(); + + // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions] + let client_first_message_bare = format!("{username},{nonce}"); + + let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}"); + + stream + .send(SaslInitialResponse { + response: &client_first_message, + plus: false, + }) + .await?; + + let cont = match stream.recv_expect().await? { + Authentication::SaslContinue(data) => data, + + auth => { + return Err(err_protocol!( + "expected SASLContinue but received {:?}", + auth + )); + } + }; + + // SaltedPassword := Hi(Normalize(password), salt, i) + let salted_password = hi( + options.password.as_deref().unwrap_or_default(), + &cont.salt, + cont.iterations, + )?; + + // ClientKey := HMAC(SaltedPassword, "Client Key") + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; + mac.update(b"Client Key"); + + let client_key = mac.finalize().into_bytes(); + + // StoredKey := H(ClientKey) + let stored_key = Sha256::digest(client_key); + + // client-final-message-without-proof + let client_final_message_wo_proof = format!( + "{channel_binding},r={nonce}", + channel_binding = channel_binding, + nonce = &cont.nonce + ); + + // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof + let auth_message = format!( + "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}", + client_first_message_bare = client_first_message_bare, + server_first_message = cont.message, + client_final_message_wo_proof = client_final_message_wo_proof + ); + + // ClientSignature := HMAC(StoredKey, AuthMessage) + let mut mac = Hmac::::new_from_slice(&stored_key).map_err(Error::protocol)?; + mac.update(auth_message.as_bytes()); + + let client_signature = mac.finalize().into_bytes(); + + // ClientProof := ClientKey XOR ClientSignature + let client_proof: Vec = client_key + .iter() + .zip(client_signature.iter()) + .map(|(&a, &b)| a ^ b) + .collect(); + + // ServerKey := HMAC(SaltedPassword, "Server Key") + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; + mac.update(b"Server Key"); + + let server_key = mac.finalize().into_bytes(); + + // ServerSignature := HMAC(ServerKey, AuthMessage) + let mut mac = Hmac::::new_from_slice(&server_key).map_err(Error::protocol)?; + mac.update(auth_message.as_bytes()); + + // client-final-message = client-final-message-without-proof "," proof + let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}="); + BASE64_STANDARD.encode_string(client_proof, &mut client_final_message); + + stream.send(SaslResponse(&client_final_message)).await?; + + let data = match stream.recv_expect().await? 
{ + Authentication::SaslFinal(data) => data, + + auth => { + return Err(err_protocol!("expected SASLFinal but received {:?}", auth)); + } + }; + + // authentication is only considered valid if this verification passes + mac.verify_slice(&data.verifier).map_err(Error::protocol)?; + + Ok(()) +} + +// nonce is a sequence of random printable bytes +fn gen_nonce() -> String { + let mut rng = rand::thread_rng(); + let count = rng.gen_range(64..128); + + // printable = %x21-2B / %x2D-7E + // ;; Printable ASCII except ",". + // ;; Note that any "printable" is also + // ;; a valid "value". + let nonce: String = std::iter::repeat(()) + .map(|()| { + let mut c = rng.gen_range(0x21u8..0x7F); + + while c == 0x2C { + c = rng.gen_range(0x21u8..0x7F); + } + + c + }) + .take(count) + .map(|c| c as char) + .collect(); + + rng.gen_range(32..128); + format!("{NONCE_ATTR}={nonce}") +} + +// Hi(str, salt, i): +fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> { + let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; + + mac.update(salt); + mac.update(&1u32.to_be_bytes()); + + let mut u = mac.finalize_reset().into_bytes(); + let mut hi = u; + + for _ in 1..iter_count { + mac.update(u.as_slice()); + u = mac.finalize_reset().into_bytes(); + hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); + } + + Ok(hi.into()) +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_sasl_hi(b: &mut test::Bencher) { + use test::black_box; + + let mut rng = rand::thread_rng(); + let nonce: Vec = std::iter::repeat(()) + .map(|()| rng.sample(rand::distributions::Alphanumeric)) + .take(64) + .collect(); + b.iter(|| { + let _ = hi( + test::black_box("secret_password"), + test::black_box(&nonce), + test::black_box(4096), + ); + }); +} diff --git a/patches/sqlx-postgres/src/connection/stream.rs b/patches/sqlx-postgres/src/connection/stream.rs new file mode 100644 index 000000000..781739992 --- /dev/null +++ b/patches/sqlx-postgres/src/connection/stream.rs @@ -0,0 +1,253 @@ +use std::collections::BTreeMap; +use std::ops::{Deref, DerefMut}; +use std::str::FromStr; + +use futures_channel::mpsc::UnboundedSender; +use futures_util::SinkExt; +use log::Level; +use sqlx_core::bytes::{Buf, Bytes}; + +use crate::connection::tls::MaybeUpgradeTls; +use crate::error::Error; +use crate::message::{ + BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification, + ParameterStatus, ReceivedMessage, +}; +use crate::net::{self, BufferedSocket, Socket}; +use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; + +// the stream is a separate type from the connection to uphold the invariant where an instantiated +// [PgConnection] is a **valid** connection to postgres + +// when a new connection is asked for, we work directly on the [PgStream] type until the +// connection is fully established + +// in other words, `self` in any PgConnection method is a live connection to postgres that +// is fully prepared to receive queries + +pub struct PgStream { + // A trait object is okay here as the buffering amortizes the overhead of both the dynamic + // function call as well as the syscall. 
+    inner: BufferedSocket<Box<dyn Socket>>,
+
+    // buffer of unreceived notification messages from `PUBLISH`
+    // this is set when creating a PgListener and only written to if that listener is
+    // re-used for query execution in-between receiving messages
+    pub(crate) notifications: Option<UnboundedSender<Notification>>,
+
+    pub(crate) parameter_statuses: BTreeMap<String, String>,
+
+    pub(crate) server_version_num: Option<u32>,
+}
+
+impl PgStream {
+    pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
+        let socket_future = match options.fetch_socket() {
+            Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
+            None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
+        };
+
+        let socket = socket_future.await?;
+
+        Ok(Self {
+            inner: BufferedSocket::new(socket),
+            notifications: None,
+            parameter_statuses: BTreeMap::default(),
+            server_version_num: None,
+        })
+    }
+
+    #[inline(always)]
+    pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
+        self.write(EncodeMessage(message))
+    }
+
+    pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
+    where
+        T: FrontendMessage,
+    {
+        self.write_msg(message)?;
+        self.flush().await?;
+        Ok(())
+    }
+
+    // Expect a specific type and format
+    pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
+        self.recv().await?.decode()
+    }
+
+    pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
+        // all packets in postgres start with a 5-byte header
+        // this header contains the message type and the total length of the message
+        let mut header: Bytes = self.inner.read(5).await?;
+
+        let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
+        let size = (header.get_u32() - 4) as usize;
+
+        let contents = self.inner.read(size).await?;
+
+        Ok(ReceivedMessage { format, contents })
+    }
+
+    // Get the next message from the server
+    // May wait for more data from the server
+    pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
+        loop {
+            let message = self.recv_unchecked().await?;
+
+            match message.format {
+                BackendMessageFormat::ErrorResponse => {
+                    // An error returned from the database server.
+                    return Err(message.decode::<PgDatabaseError>()?.into());
+                }
+
+                BackendMessageFormat::NotificationResponse => {
+                    if let Some(buffer) = &mut self.notifications {
+                        let notification: Notification = message.decode()?;
+                        let _ = buffer.send(notification).await;
+
+                        continue;
+                    }
+                }
+
+                BackendMessageFormat::ParameterStatus => {
+                    // informs the frontend about the current (initial)
+                    // setting of backend parameters
+
+                    let ParameterStatus { name, value } = message.decode()?;
+                    // TODO: handle `client_encoding`, `DateStyle` change
+
+                    match name.as_str() {
+                        "server_version" => {
+                            self.server_version_num = parse_server_version(&value);
+                        }
+                        _ => {
+                            self.parameter_statuses.insert(name, value);
+                        }
+                    }
+
+                    continue;
+                }
+
+                BackendMessageFormat::NoticeResponse => {
+                    // do we need this to be more configurable?
+ // if you are reading this comment and think so, open an issue + + let notice: Notice = message.decode()?; + + let (log_level, tracing_level) = match notice.severity() { + PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => { + (Level::Error, tracing::Level::ERROR) + } + PgSeverity::Warning => (Level::Warn, tracing::Level::WARN), + PgSeverity::Notice => (Level::Info, tracing::Level::INFO), + PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG), + PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE), + }; + + let log_is_enabled = log::log_enabled!( + target: "sqlx::postgres::notice", + log_level + ) || sqlx_core::private_tracing_dynamic_enabled!( + target: "sqlx::postgres::notice", + tracing_level + ); + if log_is_enabled { + sqlx_core::private_tracing_dynamic_event!( + target: "sqlx::postgres::notice", + tracing_level, + message = notice.message() + ); + } + + continue; + } + + _ => {} + } + + return Ok(message); + } + } +} + +impl Deref for PgStream { + type Target = BufferedSocket>; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for PgStream { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +// reference: +// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 +fn parse_server_version(s: &str) -> Option { + let mut parts = Vec::::with_capacity(3); + + let mut from = 0; + let mut chs = s.char_indices().peekable(); + while let Some((i, ch)) = chs.next() { + match ch { + '.' => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + from = i + 1; + } else { + break; + } + } + _ if ch.is_ascii_digit() => { + if chs.peek().is_none() { + if let Ok(num) = u32::from_str(&s[from..]) { + parts.push(num); + } + break; + } + } + _ => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + } + break; + } + }; + } + + let version_num = match parts.as_slice() { + [major, minor, rev] => (100 * major + minor) * 100 + rev, + [major, minor] if *major >= 10 => 100 * 100 * major + minor, + [major, minor] => (100 * major + minor) * 100, + [major] => 100 * 100 * major, + _ => return None, + }; + + Some(version_num) +} + +#[cfg(test)] +mod tests { + use super::parse_server_version; + + #[test] + fn test_parse_server_version_num() { + // old style + assert_eq!(parse_server_version("9.6.1"), Some(90601)); + // new style + assert_eq!(parse_server_version("10.1"), Some(100001)); + // old style without minor version + assert_eq!(parse_server_version("9.6devel"), Some(90600)); + // new style without minor version, e.g. 
*/ + assert_eq!(parse_server_version("10devel"), Some(100000)); + assert_eq!(parse_server_version("13devel87"), Some(130000)); + // unknown + assert_eq!(parse_server_version("unknown"), None); + } +} diff --git a/patches/sqlx-postgres/src/connection/tls.rs b/patches/sqlx-postgres/src/connection/tls.rs new file mode 100644 index 000000000..04bab793a --- /dev/null +++ b/patches/sqlx-postgres/src/connection/tls.rs @@ -0,0 +1,102 @@ +use futures_core::future::BoxFuture; + +use crate::error::Error; +use crate::net::tls::{self, TlsConfig}; +use crate::net::{Socket, SocketIntoBox, WithSocket}; + +use crate::message::SslRequest; +use crate::{PgConnectOptions, PgSslMode}; + +pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); + +impl<'a> WithSocket for MaybeUpgradeTls<'a> { + type Output = BoxFuture<'a, crate::Result>>; + + fn with_socket(self, socket: S) -> Self::Output { + Box::pin(maybe_upgrade(socket, self.0)) + } +} + +async fn maybe_upgrade( + mut socket: S, + options: &PgConnectOptions, +) -> Result, Error> { + // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS + match options.ssl_mode { + // FIXME: Implement ALLOW + PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)), + + PgSslMode::Prefer => { + if !tls::available() { + return Ok(Box::new(socket)); + } + + // try upgrade, but its okay if we fail + if !request_upgrade(&mut socket, options).await? { + return Ok(Box::new(socket)); + } + } + + PgSslMode::Require | PgSslMode::VerifyFull | PgSslMode::VerifyCa => { + tls::error_if_unavailable()?; + + if !request_upgrade(&mut socket, options).await? { + // upgrade failed, die + return Err(Error::Tls("server does not support TLS".into())); + } + } + } + + let accept_invalid_certs = !matches!( + options.ssl_mode, + PgSslMode::VerifyCa | PgSslMode::VerifyFull + ); + let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); + + let config = TlsConfig { + accept_invalid_certs, + accept_invalid_hostnames, + hostname: &options.host, + root_cert_path: options.ssl_root_cert.as_ref(), + client_cert_path: options.ssl_client_cert.as_ref(), + client_key_path: options.ssl_client_key.as_ref(), + }; + + tls::handshake(socket, config, SocketIntoBox).await +} + +async fn request_upgrade( + socket: &mut impl Socket, + _options: &PgConnectOptions, +) -> Result { + // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 + + // To initiate an SSL-encrypted connection, the frontend initially sends an + // SSLRequest message rather than a StartupMessage + + socket.write(SslRequest::BYTES).await?; + + // The server then responds with a single byte containing S or N, indicating that + // it is willing or unwilling to perform SSL, respectively. 
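For reference, the `SslRequest::BYTES` written above is a fixed 8-byte sequence defined by the wire protocol; a minimal sketch of that layout (the constant name here is illustrative, not taken from the patch):

// Assumed layout of SslRequest::BYTES: a big-endian int32 length
// (8, counting itself), then the SSLRequest code 80877103 (0x04D2162F).
const SSL_REQUEST_BYTES: [u8; 8] = [
    0x00, 0x00, 0x00, 0x08, // length = 8
    0x04, 0xD2, 0x16, 0x2F, // 80877103
];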
+ + let mut response = [0u8]; + + socket.read(&mut &mut response[..]).await?; + + match response[0] { + b'S' => { + // The server is ready and willing to accept an SSL connection + Ok(true) + } + + b'N' => { + // The server is _unwilling_ to perform SSL + Ok(false) + } + + other => Err(err_protocol!( + "unexpected response from SSLRequest: 0x{:02x}", + other + )), + } +} diff --git a/patches/sqlx-postgres/src/copy.rs b/patches/sqlx-postgres/src/copy.rs new file mode 100644 index 000000000..347877c36 --- /dev/null +++ b/patches/sqlx-postgres/src/copy.rs @@ -0,0 +1,342 @@ +use std::borrow::Cow; +use std::ops::{Deref, DerefMut}; + +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; + +use sqlx_core::bytes::{BufMut, Bytes}; + +use crate::connection::PgConnection; +use crate::error::{Error, Result}; +use crate::ext::async_stream::TryAsyncStream; +use crate::io::AsyncRead; +use crate::message::{ + BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse, + CopyOutResponse, CopyResponseData, Query, ReadyForQuery, +}; +use crate::pool::{Pool, PoolConnection}; +use crate::Postgres; + +impl PgConnection { + /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data + /// to Postgres. This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw(&mut self, statement: &str) -> Result> { + PgCopyIn::begin(self, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// + #[allow(clippy::needless_lifetimes)] + pub async fn copy_out_raw<'c>( + &'c mut self, + statement: &str, + ) -> Result>> { + pg_begin_copy_out(self, statement).await + } +} + +/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`][crate::PgPool]. +/// +/// This is a replacement for the inherent methods on `PgPool` which could not exist +/// once the Postgres driver was moved out into its own crate. +pub trait PgPoolCopyExt { + /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. 
+ /// This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// A single connection will be checked out for the duration. + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + fn copy_in_raw<'a>( + &'a self, + statement: &'a str, + ) -> BoxFuture<'a, Result>>>; + + /// Issue a `COPY TO STDOUT` statement and begin streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// + fn copy_out_raw<'a>( + &'a self, + statement: &'a str, + ) -> BoxFuture<'a, Result>>>; +} + +impl PgPoolCopyExt for Pool { + fn copy_in_raw<'a>( + &'a self, + statement: &'a str, + ) -> BoxFuture<'a, Result>>> { + Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await }) + } + + fn copy_out_raw<'a>( + &'a self, + statement: &'a str, + ) -> BoxFuture<'a, Result>>> { + Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await }) + } +} + +/// A connection in streaming `COPY FROM STDIN` mode. +/// +/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. +/// +/// ### Note +/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection +/// will return an error the next time it is used. +#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] +pub struct PgCopyIn> { + conn: Option, + response: CopyResponseData, +} + +impl> PgCopyIn { + async fn begin(mut conn: C, statement: &str) -> Result { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let response = match conn.stream.recv_expect::().await { + Ok(res) => res.0, + Err(e) => { + conn.stream.recv().await?; + return Err(e); + } + }; + + Ok(PgCopyIn { + conn: Some(conn), + response, + }) + } + + /// Returns `true` if Postgres is expecting data in text or CSV format. + pub fn is_textual(&self) -> bool { + self.response.format == 0 + } + + /// Returns the number of columns expected in the input. + pub fn num_columns(&self) -> usize { + assert_eq!( + self.response.num_columns.unsigned_abs() as usize, + self.response.format_codes.len(), + "num_columns does not match format_codes.len()" + ); + self.response.format_codes.len() + } + + /// Check if a column is expecting data in text format (`true`) or binary format (`false`). 
+ /// + /// ### Panics + /// If `column` is out of range according to [`.num_columns()`][Self::num_columns]. + pub fn column_is_textual(&self, column: usize) -> bool { + self.response.format_codes[column] == 0 + } + + /// Send a chunk of `COPY` data. + /// + /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead. + pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .stream + .send(CopyData(data)) + .await?; + + Ok(self) + } + + /// Copy data directly from `source` to the database without requiring an intermediate buffer. + /// + /// `source` will be read to the end. + /// + /// ### Note: Completion Step Required + /// You must still call either [Self::finish] or [Self::abort] to complete the process. + /// + /// ### Note: Runtime Features + /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std` + /// depending on which runtime feature is used. + /// + /// The runtime features _used_ to be mutually exclusive, but are no longer. + /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version + /// takes precedent. + pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { + let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); + loop { + let buf = conn.stream.write_buffer_mut(); + + // Write the CopyData format code and reserve space for the length. + // This may end up sending an empty `CopyData` packet if, after this point, + // we get canceled or read 0 bytes, but that should be fine. + buf.put_slice(b"d\0\0\0\x04"); + + let read = buf.read_from(&mut source).await?; + + if read == 0 { + break; + } + + // Write the length + let read32 = u32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + + (&mut buf.get_mut()[1..]).put_u32(read32 + 4); + + conn.stream.flush().await?; + } + + Ok(self) + } + + /// Signal that the `COPY` process should be aborted and any data received should be discarded. + /// + /// The given message can be used for indicating the reason for the abort in the database logs. + /// + /// The server is expected to respond with an error, so only _unexpected_ errors are returned. + pub async fn abort(mut self, msg: impl Into) -> Result<()> { + let mut conn = self + .conn + .take() + .expect("PgCopyIn::fail_with: conn taken illegally"); + + conn.stream.send(CopyFail::new(msg)).await?; + + match conn.stream.recv().await { + Ok(msg) => Err(err_protocol!( + "fail_with: expected ErrorResponse, got: {:?}", + msg.format + )), + Err(Error::Database(e)) => { + match e.code() { + Some(Cow::Borrowed("57014")) => { + // postgres abort received error code + conn.stream.recv_expect::().await?; + Ok(()) + } + _ => Err(Error::Database(e)), + } + } + Err(e) => Err(e), + } + } + + /// Signal that the `COPY` process is complete. + /// + /// The number of rows affected is returned. 
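Taken together, the begin/send/finish flow looks like the following sketch, assuming an open `PgConnection` and a hypothetical `users` table (neither is part of the patch):

// Hypothetical usage: import two CSV rows via COPY, then complete the stream.
let mut copy = conn
    .copy_in_raw("COPY users (id, name) FROM STDIN (FORMAT csv)")
    .await?;
copy.send(&b"1,alice\n2,bob\n"[..]).await?;
let rows_affected = copy.finish().await?;
assert_eq!(rows_affected, 2);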
+ pub async fn finish(mut self) -> Result { + let mut conn = self + .conn + .take() + .expect("CopyWriter::finish: conn taken illegally"); + + conn.stream.send(CopyDone).await?; + let cc: CommandComplete = match conn.stream.recv_expect().await { + Ok(cc) => cc, + Err(e) => { + conn.stream.recv().await?; + return Err(e); + } + }; + + conn.stream.recv_expect::().await?; + + Ok(cc.rows_affected()) + } +} + +impl> Drop for PgCopyIn { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + conn.stream + .write_msg(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )) + .expect("BUG: PgCopyIn abort message should not be too large"); + } + } +} + +async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( + mut conn: C, + statement: &str, +) -> Result>> { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let _: CopyOutResponse = conn.stream.recv_expect().await?; + + let stream: TryAsyncStream<'c, Bytes> = try_stream! { + loop { + match conn.stream.recv().await { + Err(e) => { + conn.stream.recv_expect::().await?; + return Err(e); + }, + Ok(msg) => match msg.format { + BackendMessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + BackendMessageFormat::CopyDone => { + let _ = msg.decode::()?; + conn.stream.recv_expect::().await?; + conn.stream.recv_expect::().await?; + return Ok(()) + }, + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + } + } + } + }; + + Ok(Box::pin(stream)) +} diff --git a/patches/sqlx-postgres/src/database.rs b/patches/sqlx-postgres/src/database.rs new file mode 100644 index 000000000..876e29589 --- /dev/null +++ b/patches/sqlx-postgres/src/database.rs @@ -0,0 +1,40 @@ +use crate::arguments::PgArgumentBuffer; +use crate::value::{PgValue, PgValueRef}; +use crate::{ + PgArguments, PgColumn, PgConnection, PgQueryResult, PgRow, PgStatement, PgTransactionManager, + PgTypeInfo, +}; + +pub(crate) use sqlx_core::database::{Database, HasStatementCache}; + +/// PostgreSQL database driver. +#[derive(Debug)] +pub struct Postgres; + +impl Database for Postgres { + type Connection = PgConnection; + + type TransactionManager = PgTransactionManager; + + type Row = PgRow; + + type QueryResult = PgQueryResult; + + type Column = PgColumn; + + type TypeInfo = PgTypeInfo; + + type Value = PgValue; + type ValueRef<'r> = PgValueRef<'r>; + + type Arguments<'q> = PgArguments; + type ArgumentBuffer<'q> = PgArgumentBuffer; + + type Statement<'q> = PgStatement<'q>; + + const NAME: &'static str = "PostgreSQL"; + + const URL_SCHEMES: &'static [&'static str] = &["postgres", "postgresql"]; +} + +impl HasStatementCache for Postgres {} diff --git a/patches/sqlx-postgres/src/error.rs b/patches/sqlx-postgres/src/error.rs new file mode 100644 index 000000000..db8bcc8a1 --- /dev/null +++ b/patches/sqlx-postgres/src/error.rs @@ -0,0 +1,242 @@ +use std::error::Error as StdError; +use std::fmt::{self, Debug, Display, Formatter}; + +use atoi::atoi; +use smallvec::alloc::borrow::Cow; +use sqlx_core::bytes::Bytes; +pub(crate) use sqlx_core::error::*; + +use crate::message::{BackendMessage, BackendMessageFormat, Notice, PgSeverity}; + +/// An error returned from the PostgreSQL database. 
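Callers typically reach this type through the `sqlx::Error::Database` variant and branch on the SQLSTATE code; a minimal sketch, assuming `res` holds the result of executing an `INSERT`:

// Illustrative only; "23505" is unique_violation (see error_codes below).
if let Err(sqlx::Error::Database(e)) = res {
    if e.code().as_deref() == Some("23505") {
        // the row already exists; handle the conflict here
    }
}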
+pub struct PgDatabaseError(pub(crate) Notice); + +// Error message fields are documented: +// https://www.postgresql.org/docs/current/protocol-error-fields.html + +impl PgDatabaseError { + #[inline] + pub fn severity(&self) -> PgSeverity { + self.0.severity() + } + + /// The [SQLSTATE](https://www.postgresql.org/docs/current/errcodes-appendix.html) code for + /// this error. + #[inline] + pub fn code(&self) -> &str { + self.0.code() + } + + /// The primary human-readable error message. This should be accurate but + /// terse (typically one line). + #[inline] + pub fn message(&self) -> &str { + self.0.message() + } + + /// An optional secondary error message carrying more detail about the problem. + /// Might run to multiple lines. + #[inline] + pub fn detail(&self) -> Option<&str> { + self.0.get(b'D') + } + + /// An optional suggestion what to do about the problem. This is intended to differ from + /// `detail` in that it offers advice (potentially inappropriate) rather than hard facts. + /// Might run to multiple lines. + #[inline] + pub fn hint(&self) -> Option<&str> { + self.0.get(b'H') + } + + /// Indicates an error cursor position as an index into the original query string; or, + /// a position into an internally generated query. + #[inline] + pub fn position(&self) -> Option> { + self.0 + .get_raw(b'P') + .and_then(atoi) + .map(PgErrorPosition::Original) + .or_else(|| { + let position = self.0.get_raw(b'p').and_then(atoi)?; + let query = self.0.get(b'q')?; + + Some(PgErrorPosition::Internal { position, query }) + }) + } + + /// An indication of the context in which the error occurred. Presently this includes a call + /// stack traceback of active procedural language functions and internally-generated queries. + /// The trace is one entry per line, most recent first. + pub fn r#where(&self) -> Option<&str> { + self.0.get(b'W') + } + + /// If this error is with a specific database object, the + /// name of the schema containing that object, if any. + pub fn schema(&self) -> Option<&str> { + self.0.get(b's') + } + + /// If this error is with a specific table, the name of the table. + pub fn table(&self) -> Option<&str> { + self.0.get(b't') + } + + /// If the error is with a specific table column, the name of the column. + pub fn column(&self) -> Option<&str> { + self.0.get(b'c') + } + + /// If the error is with a specific data type, the name of the data type. + pub fn data_type(&self) -> Option<&str> { + self.0.get(b'd') + } + + /// If the error is with a specific constraint, the name of the constraint. + /// For this purpose, indexes are constraints, even if they weren't created + /// with constraint syntax. + pub fn constraint(&self) -> Option<&str> { + self.0.get(b'n') + } + + /// The file name of the source-code location where this error was reported. + pub fn file(&self) -> Option<&str> { + self.0.get(b'F') + } + + /// The line number of the source-code location where this error was reported. + pub fn line(&self) -> Option { + self.0.get_raw(b'L').and_then(atoi) + } + + /// The name of the source-code routine reporting this error. + pub fn routine(&self) -> Option<&str> { + self.0.get(b'R') + } +} + +#[derive(Debug, Eq, PartialEq)] +pub enum PgErrorPosition<'a> { + /// A position (in characters) into the original query. + Original(usize), + + /// A position into the internally-generated query. + Internal { + /// The position in characters. + position: usize, + + /// The text of a failed internally-generated command. 
This could be, for example, + /// the SQL query issued by a PL/pgSQL function. + query: &'a str, + }, +} + +impl Debug for PgDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PgDatabaseError") + .field("severity", &self.severity()) + .field("code", &self.code()) + .field("message", &self.message()) + .field("detail", &self.detail()) + .field("hint", &self.hint()) + .field("position", &self.position()) + .field("where", &self.r#where()) + .field("schema", &self.schema()) + .field("table", &self.table()) + .field("column", &self.column()) + .field("data_type", &self.data_type()) + .field("constraint", &self.constraint()) + .field("file", &self.file()) + .field("line", &self.line()) + .field("routine", &self.routine()) + .finish() + } +} + +impl Display for PgDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.message()) + } +} + +impl StdError for PgDatabaseError {} + +impl DatabaseError for PgDatabaseError { + fn message(&self) -> &str { + self.message() + } + + fn code(&self) -> Option> { + Some(Cow::Borrowed(self.code())) + } + + #[doc(hidden)] + fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn into_error(self: Box) -> BoxDynError { + self + } + + fn is_transient_in_connect_phase(&self) -> bool { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + [ + // too_many_connections + // This may be returned if we just un-gracefully closed a connection, + // give the database a chance to notice it and clean it up. + "53300", + // cannot_connect_now + // Returned if the database is still starting up. + "57P03", + ] + .contains(&self.code()) + } + + fn constraint(&self) -> Option<&str> { + self.constraint() + } + + fn table(&self) -> Option<&str> { + self.table() + } + + fn kind(&self) -> ErrorKind { + match self.code() { + error_codes::UNIQUE_VIOLATION => ErrorKind::UniqueViolation, + error_codes::FOREIGN_KEY_VIOLATION => ErrorKind::ForeignKeyViolation, + error_codes::NOT_NULL_VIOLATION => ErrorKind::NotNullViolation, + error_codes::CHECK_VIOLATION => ErrorKind::CheckViolation, + _ => ErrorKind::Other, + } + } +} + +// ErrorResponse is the same structure as NoticeResponse but a different format code. +impl BackendMessage for PgDatabaseError { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ErrorResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(Notice::decode_body(buf)?)) + } +} + +/// For reference: +pub(crate) mod error_codes { + /// Caused when a unique or primary key is violated. + pub const UNIQUE_VIOLATION: &str = "23505"; + /// Caused when a foreign key is violated. + pub const FOREIGN_KEY_VIOLATION: &str = "23503"; + /// Caused when a column marked as NOT NULL received a null value. + pub const NOT_NULL_VIOLATION: &str = "23502"; + /// Caused when a check constraint is violated. 
+ pub const CHECK_VIOLATION: &str = "23514"; +} diff --git a/patches/sqlx-postgres/src/io/buf_mut.rs b/patches/sqlx-postgres/src/io/buf_mut.rs new file mode 100644 index 000000000..0fe3809b5 --- /dev/null +++ b/patches/sqlx-postgres/src/io/buf_mut.rs @@ -0,0 +1,58 @@ +use crate::io::{PortalId, StatementId}; + +pub trait PgBufMutExt { + fn put_length_prefixed(&mut self, f: F) -> Result<(), crate::Error> + where + F: FnOnce(&mut Vec) -> Result<(), crate::Error>; + + fn put_statement_name(&mut self, id: StatementId); + + fn put_portal_name(&mut self, id: PortalId); +} + +impl PgBufMutExt for Vec { + // writes a length-prefixed message, this is used when encoding nearly all messages as postgres + // wants us to send the length of the often-variable-sized messages up front + fn put_length_prefixed(&mut self, write_contents: F) -> Result<(), crate::Error> + where + F: FnOnce(&mut Vec) -> Result<(), crate::Error>, + { + // reserve space to write the prefixed length + let offset = self.len(); + self.extend(&[0; 4]); + + // write the main body of the message + let write_result = write_contents(self); + + let size_result = write_result.and_then(|_| { + let size = self.len() - offset; + i32::try_from(size) + .map_err(|_| err_protocol!("message size out of range for protocol: {size}")) + }); + + match size_result { + Ok(size) => { + // now calculate the size of what we wrote and set the length value + self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + Ok(()) + } + Err(e) => { + // Put the buffer back to where it was. + self.truncate(offset); + Err(e) + } + } + } + + // writes a statement name by ID + #[inline] + fn put_statement_name(&mut self, id: StatementId) { + id.put_name_with_nul(self); + } + + // writes a portal name by ID + #[inline] + fn put_portal_name(&mut self, id: PortalId) { + id.put_name_with_nul(self); + } +} diff --git a/patches/sqlx-postgres/src/io/mod.rs b/patches/sqlx-postgres/src/io/mod.rs new file mode 100644 index 000000000..72f2a978c --- /dev/null +++ b/patches/sqlx-postgres/src/io/mod.rs @@ -0,0 +1,156 @@ +mod buf_mut; + +pub use buf_mut::PgBufMutExt; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::num::{NonZeroU32, Saturating}; + +pub(crate) use sqlx_core::io::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct StatementId(IdInner); + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct PortalId(IdInner); + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct IdInner(Option); + +pub(crate) struct DisplayId { + prefix: &'static str, + id: NonZeroU32, +} + +impl StatementId { + #[allow(dead_code)] + pub const UNNAMED: Self = Self(IdInner::UNNAMED); + + pub const NAMED_START: Self = Self(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_s_"; + + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + /// Get a type to format this statement ID with [`Display`]. + /// + /// Returns `None` if this is the unnamed statement. 
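Concretely, `StatementId::NAMED_START` formats as `sqlx_s_1`, the next ID as `sqlx_s_2`, and so on, while `put_name_with_nul` writes the same name followed by a NUL terminator (e.g. `b"sqlx_s_1\0"`).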
+ #[inline(always)] + pub fn display(&self) -> Option { + self.0.display(Self::NAME_PREFIX) + } + + pub fn put_name_with_nul(&self, buf: &mut Vec) { + self.0.put_name_with_nul(Self::NAME_PREFIX, buf) + } +} + +impl Display for DisplayId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}{}", self.prefix, self.id) + } +} + +#[allow(dead_code)] +impl PortalId { + // None selects the unnamed portal + pub const UNNAMED: Self = PortalId(IdInner::UNNAMED); + + pub const NAMED_START: Self = PortalId(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_p_"; + + /// If ID represents a named portal, return the next ID, wrapping on overflow. + /// + /// If this ID represents the unnamed portal, return the same. + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + /// Calculate the number of bytes that will be written by [`Self::put_name_with_nul()`]. + pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + pub fn put_name_with_nul(&self, buf: &mut Vec) { + self.0.put_name_with_nul(Self::NAME_PREFIX, buf) + } +} + +impl IdInner { + const UNNAMED: Self = Self(None); + + const NAMED_START: Self = Self(Some(NonZeroU32::MIN)); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890)); + + #[inline(always)] + fn next(&self) -> Self { + Self( + self.0 + .map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)), + ) + } + + #[inline(always)] + fn display(&self, prefix: &'static str) -> Option { + self.0.map(|id| DisplayId { prefix, id }) + } + + #[inline(always)] + fn name_len(&self, name_prefix: &str) -> Saturating { + let mut len = Saturating(0); + + if let Some(id) = self.0 { + len += name_prefix.len(); + // estimate the length of the ID in decimal + // `.ilog10()` can't panic since the value is never zero + len += id.get().ilog10() as usize; + // add one to compensate for `ilog10()` rounding down. + len += 1; + } + + // count the NUL terminator + len += 1; + + len + } + + #[inline(always)] + fn put_name_with_nul(&self, name_prefix: &str, buf: &mut Vec) { + if let Some(id) = self.0 { + buf.extend_from_slice(name_prefix.as_bytes()); + buf.extend_from_slice(itoa::Buffer::new().format(id.get()).as_bytes()); + } + + buf.push(0); + } +} + +#[test] +fn statement_id_display_matches_encoding() { + const EXPECTED_STR: &str = "sqlx_s_1234567890"; + const EXPECTED_BYTES: &[u8] = b"sqlx_s_1234567890\0"; + + let mut bytes = Vec::new(); + + StatementId::TEST_VAL.put_name_with_nul(&mut bytes); + + assert_eq!(bytes, EXPECTED_BYTES); + + let str = StatementId::TEST_VAL.display().unwrap().to_string(); + + assert_eq!(str, EXPECTED_STR); +} diff --git a/patches/sqlx-postgres/src/lib.rs b/patches/sqlx-postgres/src/lib.rs new file mode 100644 index 000000000..c50f53067 --- /dev/null +++ b/patches/sqlx-postgres/src/lib.rs @@ -0,0 +1,77 @@ +//! **PostgreSQL** database driver. + +#[macro_use] +extern crate sqlx_core; + +use crate::executor::Executor; + +mod advisory_lock; +mod arguments; +mod column; +mod connection; +mod copy; +mod database; +mod error; +mod io; +mod listener; +mod message; +mod options; +mod query_result; +mod row; +mod statement; +mod transaction; +mod type_checking; +mod type_info; +pub mod types; +mod value; + +#[cfg(feature = "any")] +// We are hiding the any module with its AnyConnectionBackend trait +// so that IDEs don't show it in the autocompletion list +// and end users don't accidentally use it. 
This can result in +// nested transactions not behaving as expected. +// For more information, see https://github.com/launchbadge/sqlx/pull/3254#issuecomment-2144043823 +#[doc(hidden)] +pub mod any; + +#[cfg(feature = "migrate")] +mod migrate; + +#[cfg(feature = "migrate")] +mod testing; + +pub(crate) use sqlx_core::driver_prelude::*; + +pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey}; +pub use arguments::{PgArgumentBuffer, PgArguments}; +pub use column::PgColumn; +pub use connection::PgConnection; +pub use copy::{PgCopyIn, PgPoolCopyExt}; +pub use database::Postgres; +pub use error::{PgDatabaseError, PgErrorPosition}; +pub use listener::{PgListener, PgNotification}; +pub use message::PgSeverity; +pub use options::{PgConnectOptions, PgSslMode}; +pub use query_result::PgQueryResult; +pub use row::PgRow; +pub use statement::PgStatement; +pub use transaction::PgTransactionManager; +pub use type_info::{PgTypeInfo, PgTypeKind}; +pub use types::PgHasArrayType; +pub use value::{PgValue, PgValueFormat, PgValueRef}; + +/// An alias for [`Pool`][crate::pool::Pool], specialized for Postgres. +pub type PgPool = crate::pool::Pool; + +/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for Postgres. +pub type PgPoolOptions = crate::pool::PoolOptions; + +/// An alias for [`Executor<'_, Database = Postgres>`][Executor]. +pub trait PgExecutor<'c>: Executor<'c, Database = Postgres> {} +impl<'c, T: Executor<'c, Database = Postgres>> PgExecutor<'c> for T {} + +impl_into_arguments_for_arguments!(PgArguments); +impl_acquire!(Postgres, PgConnection); +impl_column_index_for_row!(PgRow); +impl_column_index_for_statement!(PgStatement); +impl_encode_for_option!(Postgres); diff --git a/patches/sqlx-postgres/src/listener.rs b/patches/sqlx-postgres/src/listener.rs new file mode 100644 index 000000000..43bd3c8ff --- /dev/null +++ b/patches/sqlx-postgres/src/listener.rs @@ -0,0 +1,457 @@ +use std::fmt::{self, Debug}; +use std::io; +use std::str::from_utf8; + +use futures_channel::mpsc; +use futures_core::future::BoxFuture; +use futures_core::stream::{BoxStream, Stream}; +use futures_util::{FutureExt, StreamExt, TryStreamExt}; +use sqlx_core::Either; + +use crate::describe::Describe; +use crate::error::Error; +use crate::executor::{Execute, Executor}; +use crate::message::{BackendMessageFormat, Notification}; +use crate::pool::PoolOptions; +use crate::pool::{Pool, PoolConnection}; +use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; + +/// A stream of asynchronous notifications from Postgres. +/// +/// This listener will auto-reconnect. If the active +/// connection being used ever dies, this listener will detect that event, create a +/// new connection, will re-subscribe to all of the originally specified channels, and will resume +/// operations as normal. +pub struct PgListener { + pool: Pool, + connection: Option>, + buffer_rx: mpsc::UnboundedReceiver, + buffer_tx: Option>, + channels: Vec, + ignore_close_event: bool, +} + +/// An asynchronous notification from Postgres. 
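An end-to-end sketch of the listen/notify flow built from the methods below, assuming a `pool` and a hypothetical channel name (both are this sketch's assumptions):

// Subscribe, raise a notification from SQL, then receive it.
let mut listener = PgListener::connect_with(&pool).await?;
listener.listen("my_channel").await?;

sqlx::query("SELECT pg_notify('my_channel', 'hello')")
    .execute(&pool)
    .await?;

let notification = listener.recv().await?;
assert_eq!(notification.payload(), "hello");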
+pub struct PgNotification(Notification); + +impl PgListener { + pub async fn connect(url: &str) -> Result { + // Create a pool of 1 without timeouts (as they don't apply here) + // We only use the pool to handle re-connections + let pool = PoolOptions::::new() + .max_connections(1) + .max_lifetime(None) + .idle_timeout(None) + .connect(url) + .await?; + + let mut this = Self::connect_with(&pool).await?; + // We don't need to handle close events + this.ignore_close_event = true; + + Ok(this) + } + + pub async fn connect_with(pool: &Pool) -> Result { + // Pull out an initial connection + let mut connection = pool.acquire().await?; + + // Setup a notification buffer + let (sender, receiver) = mpsc::unbounded(); + connection.stream.notifications = Some(sender); + + Ok(Self { + pool: pool.clone(), + connection: Some(connection), + buffer_rx: receiver, + buffer_tx: None, + channels: Vec::new(), + ignore_close_event: false, + }) + } + + /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`. + /// + /// By default, when [`Pool::close()`] is called on the pool this listener is using + /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is + /// cancelled and `Err(PoolClosed)` is returned. + /// + /// This is because `Pool::close()` will wait until _all_ connections are returned and closed, + /// including the one being used by this listener. + /// + /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a + /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was + /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)` + /// on the attempt to acquire a new connection anyway. + /// + /// However, if you want `PgListener` to ignore the close event and continue waiting for a + /// message as long as it can, set this to `true`. + /// + /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an + /// internal pool just for the new instance of `PgListener` which cannot be closed manually. + pub fn ignore_pool_close_event(&mut self, val: bool) { + self.ignore_close_event = val; + } + + /// Starts listening for notifications on a channel. + /// The channel name is quoted here to ensure case sensitivity. + pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { + self.connection() + .await? + .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) + .await?; + + self.channels.push(channel.to_owned()); + + Ok(()) + } + + /// Starts listening for notifications on all channels. + pub async fn listen_all( + &mut self, + channels: impl IntoIterator, + ) -> Result<(), Error> { + let beg = self.channels.len(); + self.channels.extend(channels.into_iter().map(|s| s.into())); + + let query = build_listen_all_query(&self.channels[beg..]); + self.connection().await?.execute(&*query).await?; + + Ok(()) + } + + /// Stops listening for notifications on a channel. + /// The channel name is quoted here to ensure case sensitivity. 
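Because the name is double-quoted, channel matching is case-sensitive: `listener.listen("MyChannel")` issues `LISTEN "MyChannel"`, a different channel from an unquoted `LISTEN MyChannel`, which PostgreSQL folds to `mychannel`.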
+ pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> { + // use RAW connection and do NOT re-connect automatically, since this is not required for + // UNLISTEN (we've disconnected anyways) + if let Some(connection) = self.connection.as_mut() { + connection + .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .await?; + } + + if let Some(pos) = self.channels.iter().position(|s| s == channel) { + self.channels.remove(pos); + } + + Ok(()) + } + + /// Stops listening for notifications on all channels. + pub async fn unlisten_all(&mut self) -> Result<(), Error> { + // use RAW connection and do NOT re-connect automatically, since this is not required for + // UNLISTEN (we've disconnected anyways) + if let Some(connection) = self.connection.as_mut() { + connection.execute("UNLISTEN *").await?; + } + + self.channels.clear(); + + Ok(()) + } + + #[inline] + async fn connect_if_needed(&mut self) -> Result<(), Error> { + if self.connection.is_none() { + let mut connection = self.pool.acquire().await?; + connection.stream.notifications = self.buffer_tx.take(); + + connection + .execute(&*build_listen_all_query(&self.channels)) + .await?; + + self.connection = Some(connection); + } + + Ok(()) + } + + #[inline] + async fn connection(&mut self) -> Result<&mut PgConnection, Error> { + // Ensure we have an active connection to work with. + self.connect_if_needed().await?; + + Ok(self.connection.as_mut().unwrap()) + } + + /// Receives the next notification available from any of the subscribed channels. + /// + /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next + /// call to `recv()`, and should be entirely transparent (as long as it was just an + /// intermittent network failure or long-lived connection reaper). + /// + /// As notifications are transient, any received while the connection was lost, will not + /// be returned. If you'd prefer the reconnection to be explicit and have a chance to + /// do something before, please see [`try_recv`](Self::try_recv). + /// + /// # Example + /// + /// ```rust,no_run + /// # use sqlx::postgres::PgListener; + /// # + /// # sqlx::__rt::test_block_on(async move { + /// let mut listener = PgListener::connect("postgres:// ...").await?; + /// loop { + /// // ask for next notification, re-connecting (transparently) if needed + /// let notification = listener.recv().await?; + /// + /// // handle notification, do something interesting + /// } + /// # Result::<(), sqlx::Error>::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn recv(&mut self) -> Result { + loop { + if let Some(notification) = self.try_recv().await? { + return Ok(notification); + } + } + } + + /// Receives the next notification available from any of the subscribed channels. + /// + /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is + /// reconnected on the next call to `try_recv()`. + /// + /// # Example + /// + /// ```rust,no_run + /// # use sqlx::postgres::PgListener; + /// # + /// # sqlx::__rt::test_block_on(async move { + /// # let mut listener = PgListener::connect("postgres:// ...").await?; + /// loop { + /// // start handling notifications, connecting if needed + /// while let Some(notification) = listener.try_recv().await? 
{ + /// // handle notification + /// } + /// + /// // connection lost, do something interesting + /// } + /// # Result::<(), sqlx::Error>::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn try_recv(&mut self) -> Result, Error> { + // Flush the buffer first, if anything + // This would only fill up if this listener is used as a connection + if let Ok(Some(notification)) = self.buffer_rx.try_next() { + return Ok(Some(PgNotification(notification))); + } + + // Fetch our `CloseEvent` listener, if applicable. + let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event()); + + loop { + let next_message = self.connection().await?.stream.recv_unchecked(); + + let res = if let Some(ref mut close_event) = close_event { + // cancels the wait and returns `Err(PoolClosed)` if the pool is closed + // before `next_message` returns, or if the pool was already closed + close_event.do_until(next_message).await? + } else { + next_message.await + }; + + let message = match res { + Ok(message) => message, + + // The connection is dead, ensure that it is dropped, + // update self state, and loop to try again. + Err(Error::Io(err)) + if (err.kind() == io::ErrorKind::ConnectionAborted + || err.kind() == io::ErrorKind::UnexpectedEof) => + { + self.buffer_tx = self.connection().await?.stream.notifications.take(); + self.connection = None; + + // lost connection + return Ok(None); + } + + // Forward other errors + Err(error) => { + return Err(error); + } + }; + + match message.format { + // We've received an async notification, return it. + BackendMessageFormat::NotificationResponse => { + return Ok(Some(PgNotification(message.decode()?))); + } + + // Mark the connection as ready for another query + BackendMessageFormat::ReadyForQuery => { + self.connection().await?.pending_ready_for_query_count -= 1; + } + + // Ignore unexpected messages + _ => {} + } + } + } + + /// Consume this listener, returning a `Stream` of notifications. + /// + /// The backing connection will be automatically reconnected should it be lost. + /// + /// This has the same potential drawbacks as [`recv`](PgListener::recv). + /// + pub fn into_stream(mut self) -> impl Stream> + Unpin { + Box::pin(try_stream! { + loop { + r#yield!(self.recv().await?); + } + }) + } +} + +impl Drop for PgListener { + fn drop(&mut self) { + if let Some(mut conn) = self.connection.take() { + let fut = async move { + let _ = conn.execute("UNLISTEN *").await; + + // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task + // otherwise, it may trigger a panic if this task is dropped because the runtime is going away: + // https://github.com/launchbadge/sqlx/issues/1389 + conn.return_to_pool().await; + }; + + // Unregister any listeners before returning the connection to the pool. 
+ crate::rt::spawn(fut); + } + } +} + +impl<'c> Executor<'c> for &'c mut PgListener { + type Database = Postgres; + + fn fetch_many<'e, 'q, E>( + self, + query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + futures_util::stream::once(async move { + // need some basic type annotation to help the compiler a bit + let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query)); + res + }) + .try_flatten() + .boxed() + } + + fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + async move { self.connection().await?.fetch_optional(query).await }.boxed() + } + + fn prepare_with<'e, 'q: 'e>( + self, + query: &'q str, + parameters: &'e [PgTypeInfo], + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + async move { + self.connection() + .await? + .prepare_with(query, parameters) + .await + } + .boxed() + } + + #[doc(hidden)] + fn describe<'e, 'q: 'e>( + self, + query: &'q str, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + async move { self.connection().await?.describe(query).await }.boxed() + } +} + +impl PgNotification { + /// The process ID of the notifying backend process. + #[inline] + pub fn process_id(&self) -> u32 { + self.0.process_id + } + + /// The channel that the notify has been raised on. This can be thought + /// of as the message topic. + #[inline] + pub fn channel(&self) -> &str { + from_utf8(&self.0.channel).unwrap() + } + + /// The payload of the notification. An empty payload is received as an + /// empty string. + #[inline] + pub fn payload(&self) -> &str { + from_utf8(&self.0.payload).unwrap() + } +} + +impl Debug for PgListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgListener").finish() + } +} + +impl Debug for PgNotification { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgNotification") + .field("process_id", &self.process_id()) + .field("channel", &self.channel()) + .field("payload", &self.payload()) + .finish() + } +} + +fn ident(mut name: &str) -> String { + // If the input string contains a NUL byte, we should truncate the + // identifier. 
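For example, `ident("na\0me")` yields `na`, and `ident(r#"weird"name"#)` yields `weird""name`, so the statement built by `listen` becomes `LISTEN "weird""name"`.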
+ if let Some(index) = name.find('\0') { + name = &name[..index]; + } + + // Any double quotes must be escaped + name.replace('"', "\"\"") +} + +fn build_listen_all_query(channels: impl IntoIterator>) -> String { + channels.into_iter().fold(String::new(), |mut acc, chan| { + acc.push_str(r#"LISTEN ""#); + acc.push_str(&ident(chan.as_ref())); + acc.push_str(r#"";"#); + acc + }) +} + +#[test] +fn test_build_listen_all_query_with_single_channel() { + let output = build_listen_all_query(&["test"]); + assert_eq!(output.as_str(), r#"LISTEN "test";"#); +} + +#[test] +fn test_build_listen_all_query_with_multiple_channels() { + let output = build_listen_all_query(&["channel.0", "channel.1"]); + assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#); +} diff --git a/patches/sqlx-postgres/src/message/authentication.rs b/patches/sqlx-postgres/src/message/authentication.rs new file mode 100644 index 000000000..3a3cf7ff6 --- /dev/null +++ b/patches/sqlx-postgres/src/message/authentication.rs @@ -0,0 +1,193 @@ +use std::str::from_utf8; + +use memchr::memchr; +use sqlx_core::bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::ProtocolDecode; + +use crate::message::{BackendMessage, BackendMessageFormat}; +use base64::prelude::{Engine as _, BASE64_STANDARD}; +// On startup, the server sends an appropriate authentication request message, +// to which the frontend must reply with an appropriate authentication +// response message (such as a password). + +// For all authentication methods except GSSAPI, SSPI and SASL, there is at +// most one request and one response. In some methods, no response at all is +// needed from the frontend, and so no authentication request occurs. + +// For GSSAPI, SSPI and SASL, multiple exchanges of packets may +// be needed to complete the authentication. + +// +// + +#[derive(Debug)] +pub enum Authentication { + /// The authentication exchange is successfully completed. + Ok, + + /// The frontend must now send a [PasswordMessage] containing the + /// password in clear-text form. + CleartextPassword, + + /// The frontend must now send a [PasswordMessage] containing the + /// password (with user name) encrypted via MD5, then encrypted + /// again using the 4-byte random salt. + Md5Password(AuthenticationMd5Password), + + /// The frontend must now initiate a SASL negotiation, + /// using one of the SASL mechanisms listed in the message. + /// + /// The frontend will send a [SaslInitialResponse] with the name + /// of the selected mechanism, and the first part of the SASL + /// data stream in response to this. + /// + /// If further messages are needed, the server will + /// respond with [Authentication::SaslContinue]. + Sasl(AuthenticationSasl), + + /// This message contains challenge data from the previous step of SASL negotiation. + /// + /// The frontend must respond with a [SaslResponse] message. + SaslContinue(AuthenticationSaslContinue), + + /// SASL authentication has completed with additional mechanism-specific + /// data for the client. + /// + /// The server will next send [Authentication::Ok] to + /// indicate successful authentication. 
+ SaslFinal(AuthenticationSaslFinal), +} + +impl BackendMessage for Authentication { + const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication; + + fn decode_body(mut buf: Bytes) -> Result { + Ok(match buf.get_u32() { + 0 => Authentication::Ok, + + 3 => Authentication::CleartextPassword, + + 5 => { + let mut salt = [0; 4]; + buf.copy_to_slice(&mut salt); + + Authentication::Md5Password(AuthenticationMd5Password { salt }) + } + + 10 => Authentication::Sasl(AuthenticationSasl(buf)), + 11 => Authentication::SaslContinue(AuthenticationSaslContinue::decode(buf)?), + 12 => Authentication::SaslFinal(AuthenticationSaslFinal::decode(buf)?), + + ty => { + return Err(err_protocol!("unknown authentication method: {}", ty)); + } + }) + } +} + +/// Body of [Authentication::Md5Password]. +#[derive(Debug)] +pub struct AuthenticationMd5Password { + pub salt: [u8; 4], +} + +/// Body of [Authentication::Sasl]. +#[derive(Debug)] +pub struct AuthenticationSasl(Bytes); + +impl AuthenticationSasl { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +/// An iterator over the SASL authentication mechanisms provided by the server. +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> Iterator for SaslMechanisms<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + if !self.0.is_empty() && self.0[0] == b'\0' { + return None; + } + + let mechanism = memchr(b'\0', self.0).and_then(|nul| from_utf8(&self.0[..nul]).ok())?; + + self.0 = &self.0[(mechanism.len() + 1)..]; + + Some(mechanism) + } +} + +#[derive(Debug)] +pub struct AuthenticationSaslContinue { + pub salt: Vec, + pub iterations: u32, + pub nonce: String, + pub message: String, +} + +impl ProtocolDecode<'_> for AuthenticationSaslContinue { + fn decode_with(buf: Bytes, _: ()) -> Result { + let mut iterations: u32 = 4096; + let mut salt = Vec::new(); + let mut nonce = Bytes::new(); + + // [Example] + // r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096 + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + match key { + b'r' => { + nonce = buf.slice_ref(value); + } + + b'i' => { + iterations = atoi::atoi(value).unwrap_or(4096); + } + + b's' => { + salt = BASE64_STANDARD.decode(value).map_err(Error::protocol)?; + } + + _ => {} + } + } + + Ok(Self { + iterations, + salt, + nonce: from_utf8(&nonce).map_err(Error::protocol)?.to_owned(), + message: from_utf8(&buf).map_err(Error::protocol)?.to_owned(), + }) + } +} + +#[derive(Debug)] +pub struct AuthenticationSaslFinal { + pub verifier: Vec, +} + +impl ProtocolDecode<'_> for AuthenticationSaslFinal { + fn decode_with(buf: Bytes, _: ()) -> Result { + let mut verifier = Vec::new(); + + for item in buf.split(|b| *b == b',') { + let key = item[0]; + let value = &item[2..]; + + if let b'v' = key { + verifier = BASE64_STANDARD.decode(value).map_err(Error::protocol)?; + } + } + + Ok(Self { verifier }) + } +} diff --git a/patches/sqlx-postgres/src/message/backend_key_data.rs b/patches/sqlx-postgres/src/message/backend_key_data.rs new file mode 100644 index 000000000..f2dc2f232 --- /dev/null +++ b/patches/sqlx-postgres/src/message/backend_key_data.rs @@ -0,0 +1,50 @@ +use byteorder::{BigEndian, ByteOrder}; +use sqlx_core::bytes::Bytes; + +use crate::error::Error; +use crate::message::{BackendMessage, BackendMessageFormat}; + +/// Contains cancellation key data. The frontend must save these values if it +/// wishes to be able to issue `CancelRequest` messages later. 
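The CancelRequest these fields enable is a fixed 16-byte message, sent over a new connection rather than the original one; a minimal sketch of its encoding (the function is illustrative, not part of the patch):

// Builds the 16-byte CancelRequest frame from a saved BackendKeyData.
fn cancel_request_bytes(process_id: u32, secret_key: u32) -> [u8; 16] {
    let mut buf = [0u8; 16];
    buf[0..4].copy_from_slice(&16u32.to_be_bytes()); // length, counting itself
    buf[4..8].copy_from_slice(&80877102u32.to_be_bytes()); // CancelRequest code
    buf[8..12].copy_from_slice(&process_id.to_be_bytes());
    buf[12..16].copy_from_slice(&secret_key.to_be_bytes());
    buf
}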
+#[derive(Debug)] +pub struct BackendKeyData { + /// The process ID of this database. + pub process_id: u32, + + /// The secret key of this database. + pub secret_key: u32, +} + +impl BackendMessage for BackendKeyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData; + + fn decode_body(buf: Bytes) -> Result { + let process_id = BigEndian::read_u32(&buf); + let secret_key = BigEndian::read_u32(&buf[4..]); + + Ok(Self { + process_id, + secret_key, + }) + } +} + +#[test] +fn test_decode_backend_key_data() { + const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; + + let m = BackendKeyData::decode_body(DATA.into()).unwrap(); + + assert_eq!(m.process_id, 10182); + assert_eq!(m.secret_key, 2303903019); +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_decode_backend_key_data(b: &mut test::Bencher) { + const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; + + b.iter(|| { + BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); + }); +} diff --git a/patches/sqlx-postgres/src/message/bind.rs b/patches/sqlx-postgres/src/message/bind.rs new file mode 100644 index 000000000..83631fea5 --- /dev/null +++ b/patches/sqlx-postgres/src/message/bind.rs @@ -0,0 +1,95 @@ +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use crate::PgValueFormat; +use std::num::Saturating; + +#[derive(Debug)] +pub struct Bind<'a> { + /// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal). + pub portal: PortalId, + + /// The id of the source prepared statement. + pub statement: StatementId, + + /// The parameter format codes. Each must presently be zero (text) or one (binary). + /// + /// There can be zero to indicate that there are no parameters or that the parameters all use the + /// default format (text); or one, in which case the specified format code is applied to all + /// parameters; or it can equal the actual number of parameters. + pub formats: &'a [PgValueFormat], + + /// The number of parameters. + /// + /// May be different from `formats.len()` + pub num_params: i16, + + /// The value of each parameter, in the indicated format. + pub params: &'a [u8], + + /// The result-column format codes. Each must presently be zero (text) or one (binary). + /// + /// There can be zero to indicate that there are no result columns or that the + /// result columns should all use the default format (text); or one, in which + /// case the specified format code is applied to all result columns (if any); + /// or it can equal the actual number of result columns of the query. 
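In other words, the legal shapes are an empty slice (every result column defaults to text), a single code (applied to every result column), or exactly one code per result column.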
+ pub result_formats: &'a [PgValueFormat], +} + +impl FrontendMessage for Bind<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind; + + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + size += self.portal.name_len(); + size += self.statement.name_len(); + + // Parameter formats and length prefix + size += 2; + size += self.formats.len(); + + // `num_params` + size += 2; + + size += self.params.len(); + + // Result formats and length prefix + size += 2; + size += self.result_formats.len(); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + buf.put_portal_name(self.portal); + + buf.put_statement_name(self.statement); + + let formats_len = i16::try_from(self.formats.len()).map_err(|_| { + err_protocol!("too many parameter format codes ({})", self.formats.len()) + })?; + + buf.extend(formats_len.to_be_bytes()); + + for &format in self.formats { + buf.extend((format as i16).to_be_bytes()); + } + + buf.extend(self.num_params.to_be_bytes()); + + buf.extend(self.params); + + let result_formats_len = i16::try_from(self.formats.len()) + .map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?; + + buf.extend(result_formats_len.to_be_bytes()); + + for &format in self.result_formats { + buf.extend((format as i16).to_be_bytes()); + } + + Ok(()) + } +} + +// TODO: Unit Test Bind +// TODO: Benchmark Bind diff --git a/patches/sqlx-postgres/src/message/close.rs b/patches/sqlx-postgres/src/message/close.rs new file mode 100644 index 000000000..172f244c1 --- /dev/null +++ b/patches/sqlx-postgres/src/message/close.rs @@ -0,0 +1,45 @@ +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use std::num::Saturating; + +const CLOSE_PORTAL: u8 = b'P'; +const CLOSE_STATEMENT: u8 = b'S'; + +#[derive(Debug)] +#[allow(dead_code)] +pub enum Close { + Statement(StatementId), + Portal(PortalId), +} + +impl FrontendMessage for Close { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close; + + fn body_size_hint(&self) -> Saturating { + // Either `CLOSE_PORTAL` or `CLOSE_STATEMENT` + let mut size = Saturating(1); + + match self { + Close::Statement(id) => size += id.name_len(), + Close::Portal(id) => size += id.name_len(), + } + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + match self { + Close::Statement(id) => { + buf.push(CLOSE_STATEMENT); + buf.put_statement_name(*id); + } + + Close::Portal(id) => { + buf.push(CLOSE_PORTAL); + buf.put_portal_name(*id); + } + } + + Ok(()) + } +} diff --git a/patches/sqlx-postgres/src/message/command_complete.rs b/patches/sqlx-postgres/src/message/command_complete.rs new file mode 100644 index 000000000..eb33c512d --- /dev/null +++ b/patches/sqlx-postgres/src/message/command_complete.rs @@ -0,0 +1,82 @@ +use atoi::atoi; +use memchr::memrchr; +use sqlx_core::bytes::Bytes; + +use crate::error::Error; +use crate::message::{BackendMessage, BackendMessageFormat}; + +#[derive(Debug)] +pub struct CommandComplete { + /// The command tag. This is usually a single word that identifies which SQL command + /// was completed. + tag: Bytes, +} + +impl BackendMessage for CommandComplete { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete; + + fn decode_body(bytes: Bytes) -> Result { + Ok(CommandComplete { tag: bytes }) + } +} + +impl CommandComplete { + /// Returns the number of rows affected. 
+ /// If the command does not return rows (e.g., "CREATE TABLE"), returns 0. + pub fn rows_affected(&self) -> u64 { + // Look backwards for the first SPACE + memrchr(b' ', &self.tag) + // This is either a word or the number of rows affected + .and_then(|i| atoi(&self.tag[(i + 1)..])) + .unwrap_or(0) + } +} + +#[test] +fn test_decode_command_complete_for_insert() { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 1214); +} + +#[test] +fn test_decode_command_complete_for_begin() { + const DATA: &[u8] = b"BEGIN\0"; + + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 0); +} + +#[test] +fn test_decode_command_complete_for_update() { + const DATA: &[u8] = b"UPDATE 5\0"; + + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); + + assert_eq!(cc.rows_affected(), 5); +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_decode_command_complete(b: &mut test::Bencher) { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + b.iter(|| { + let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA))); + }); +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { + const DATA: &[u8] = b"INSERT 0 1214\0"; + + let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); + + b.iter(|| { + let _rows = test::black_box(&data).rows_affected(); + }); +} diff --git a/patches/sqlx-postgres/src/message/copy.rs b/patches/sqlx-postgres/src/message/copy.rs new file mode 100644 index 000000000..837d849a0 --- /dev/null +++ b/patches/sqlx-postgres/src/message/copy.rs @@ -0,0 +1,141 @@ +use crate::error::Result; +use crate::io::BufMutExt; +use crate::message::{ + BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat, +}; +use sqlx_core::bytes::{Buf, Bytes}; +use sqlx_core::Error; +use std::num::Saturating; +use std::ops::Deref; + +/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` +pub struct CopyResponseData { + pub format: i8, + pub num_columns: i16, + pub format_codes: Vec, +} + +pub struct CopyInResponse(pub CopyResponseData); + +#[allow(dead_code)] +pub struct CopyOutResponse(pub CopyResponseData); + +pub struct CopyData(pub B); + +pub struct CopyFail { + pub message: String, +} + +pub struct CopyDone; + +impl CopyResponseData { + #[inline] + fn decode(mut buf: Bytes) -> Result { + let format = buf.get_i8(); + let num_columns = buf.get_i16(); + + let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); + + Ok(CopyResponseData { + format, + num_columns, + format_codes, + }) + } +} + +impl BackendMessage for CopyInResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) + } +} + +impl BackendMessage for CopyOutResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) + } +} + +impl BackendMessage for CopyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(buf)) + } +} + +impl> FrontendMessage for CopyData { + const FORMAT: FrontendMessageFormat = 
diff --git a/patches/sqlx-postgres/src/message/copy.rs b/patches/sqlx-postgres/src/message/copy.rs
new file mode 100644
index 000000000..837d849a0
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/copy.rs
@@ -0,0 +1,141 @@
+use crate::error::Result;
+use crate::io::BufMutExt;
+use crate::message::{
+    BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat,
+};
+use sqlx_core::bytes::{Buf, Bytes};
+use sqlx_core::Error;
+use std::num::Saturating;
+use std::ops::Deref;
+
+/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse`
+pub struct CopyResponseData {
+    pub format: i8,
+    pub num_columns: i16,
+    pub format_codes: Vec<i16>,
+}
+
+pub struct CopyInResponse(pub CopyResponseData);
+
+#[allow(dead_code)]
+pub struct CopyOutResponse(pub CopyResponseData);
+
+pub struct CopyData<B>(pub B);
+
+pub struct CopyFail {
+    pub message: String,
+}
+
+pub struct CopyDone;
+
+impl CopyResponseData {
+    #[inline]
+    fn decode(mut buf: Bytes) -> Result<Self> {
+        let format = buf.get_i8();
+        let num_columns = buf.get_i16();
+
+        let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect();
+
+        Ok(CopyResponseData {
+            format,
+            num_columns,
+            format_codes,
+        })
+    }
+}
+
+impl BackendMessage for CopyInResponse {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse;
+
+    #[inline(always)]
+    fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
+        Ok(Self(CopyResponseData::decode(buf)?))
+    }
+}
+
+impl BackendMessage for CopyOutResponse {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse;
+
+    #[inline(always)]
+    fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
+        Ok(Self(CopyResponseData::decode(buf)?))
+    }
+}
+
+impl BackendMessage for CopyData<Bytes> {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData;
+
+    #[inline(always)]
+    fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
+        Ok(Self(buf))
+    }
+}
+
+impl<B: Deref<Target = [u8]>> FrontendMessage for CopyData<B> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(self.0.len())
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        buf.extend_from_slice(&self.0);
+        Ok(())
+    }
+}
+
+impl FrontendMessage for CopyFail {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(self.message.len())
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
+        buf.put_str_nul(&self.message);
+        Ok(())
+    }
+}
+
+impl CopyFail {
+    #[inline(always)]
+    pub fn new(msg: impl Into<String>) -> CopyFail {
+        CopyFail {
+            message: msg.into(),
+        }
+    }
+}
+
+impl FrontendMessage for CopyDone {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(0)
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, _buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
+        Ok(())
+    }
+}
+
+impl BackendMessage for CopyDone {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone;
+
+    #[inline(always)]
+    fn decode_body(bytes: Bytes) -> std::result::Result<Self, Error> {
+        if !bytes.is_empty() {
+            // Not fatal but may indicate a protocol change
+            tracing::debug!(
+                "Postgres backend returned non-empty message for CopyDone: \"{}\"",
+                bytes.escape_ascii()
+            )
+        }
+
+        Ok(CopyDone)
+    }
+}
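`CopyInResponse` and `CopyOutResponse` share one body layout: an Int8 overall format (0 = text, 1 = binary), an Int16 column count, then one Int16 format code per column. A std-only sketch of that decode, with a hypothetical `parse_copy_response` helper standing in for the `Bytes`-based one above:

    fn parse_copy_response(buf: &[u8]) -> Option<(i8, Vec<i16>)> {
        let format = *buf.first()? as i8;
        // clamp a (bogus) negative count to zero before using it as a capacity
        let num_columns = i16::from_be_bytes([*buf.get(1)?, *buf.get(2)?]).max(0) as usize;

        let mut codes = Vec::with_capacity(num_columns);
        for i in 0..num_columns {
            let at = 3 + i * 2;
            codes.push(i16::from_be_bytes([*buf.get(at)?, *buf.get(at + 1)?]));
        }
        Some((format, codes))
    }

    #[test]
    fn copy_response_sketch() {
        // overall format 1 (binary), two columns, both with format code 1
        assert_eq!(
            parse_copy_response(&[1, 0, 2, 0, 1, 0, 1]),
            Some((1, vec![1, 1]))
        );
    }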
diff --git a/patches/sqlx-postgres/src/message/data_row.rs b/patches/sqlx-postgres/src/message/data_row.rs
new file mode 100644
index 000000000..ae9d0d9b2
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/data_row.rs
@@ -0,0 +1,138 @@
+use byteorder::{BigEndian, ByteOrder};
+use sqlx_core::bytes::Bytes;
+use std::ops::Range;
+
+use crate::error::Error;
+use crate::message::{BackendMessage, BackendMessageFormat};
+
+/// A row of data from the database.
+#[derive(Debug)]
+pub struct DataRow {
+    pub(crate) storage: Bytes,
+
+    /// Ranges into the stored row data.
+    /// This uses `u32` instead of `usize` to reduce the size of this type. Values cannot be
+    /// larger than `i32` in Postgres.
+    pub(crate) values: Vec<Option<Range<u32>>>,
+}
+
+impl DataRow {
+    #[inline]
+    pub(crate) fn get(&self, index: usize) -> Option<&'_ [u8]> {
+        self.values[index]
+            .as_ref()
+            .map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
+    }
+}
+
+impl BackendMessage for DataRow {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow;
+
+    fn decode_body(buf: Bytes) -> Result<Self, Error> {
+        if buf.len() < 2 {
+            return Err(err_protocol!(
+                "expected at least 2 bytes, got {}",
+                buf.len()
+            ));
+        }
+
+        let cnt = BigEndian::read_u16(&buf) as usize;
+
+        let mut values = Vec::with_capacity(cnt);
+        let mut offset: u32 = 2;
+
+        for _ in 0..cnt {
+            let value_start = offset
+                .checked_add(4)
+                .ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?;
+
+            // widen both to a larger type for a safe comparison
+            if (buf.len() as u64) < (value_start as u64) {
+                return Err(err_protocol!(
+                    "expected 4 bytes at offset {offset}, got {}",
+                    (value_start as u64) - (buf.len() as u64)
+                ));
+            }
+
+            // Length of the column value, in bytes (this count does not include itself).
+            // Can be zero. As a special case, -1 indicates a NULL column value.
+            // No value bytes follow in the NULL case.
+            //
+            // we know `offset` is within range of `buf.len()` from the above check
+            #[allow(clippy::cast_possible_truncation)]
+            let length = BigEndian::read_i32(&buf[(offset as usize)..]);
+
+            if let Ok(length) = u32::try_from(length) {
+                let value_end = value_start.checked_add(length).ok_or_else(|| {
+                    err_protocol!("value_start + length out of range ({offset} + {length})")
+                })?;
+
+                values.push(Some(value_start..value_end));
+                offset = value_end;
+            } else {
+                // Negative values signify NULL
+                values.push(None);
+                // `value_start` is actually the next value now.
+                offset = value_start;
+            }
+        }
+
+        Ok(Self {
+            storage: buf,
+            values,
+        })
+    }
+}
+
+#[test]
+fn test_decode_data_row() {
+    const DATA: &[u8] = b"\
+        \x00\x08\
+        \xff\xff\xff\xff\
+        \x00\x00\x00\x04\
+        \x00\x00\x00\n\
+        \xff\xff\xff\xff\
+        \x00\x00\x00\x04\
+        \x00\x00\x00\x14\
+        \xff\xff\xff\xff\
+        \x00\x00\x00\x04\
+        \x00\x00\x00(\
+        \xff\xff\xff\xff\
+        \x00\x00\x00\x04\
+        \x00\x00\x00P";
+
+    let row = DataRow::decode_body(DATA.into()).unwrap();
+
+    assert_eq!(row.values.len(), 8);
+
+    assert!(row.get(0).is_none());
+    assert_eq!(row.get(1).unwrap(), &[0_u8, 0, 0, 10][..]);
+    assert!(row.get(2).is_none());
+    assert_eq!(row.get(3).unwrap(), &[0_u8, 0, 0, 20][..]);
+    assert!(row.get(4).is_none());
+    assert_eq!(row.get(5).unwrap(), &[0_u8, 0, 0, 40][..]);
+    assert!(row.get(6).is_none());
+    assert_eq!(row.get(7).unwrap(), &[0_u8, 0, 0, 80][..]);
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_data_row_get(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
+
+    let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
+
+    b.iter(|| {
+        let _value = test::black_box(&row).get(3);
+    });
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_decode_data_row(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
+
+    b.iter(|| {
+        let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA)));
+    });
+}
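A DataRow body is an Int16 column count followed by, for each column, an Int32 length (-1 marks NULL, with no value bytes) and then that many raw bytes. A simplified std-only sketch of the same range computation (hypothetical helper; the real decode above additionally guards every offset arithmetic step against overflow):

    fn data_row_ranges(buf: &[u8]) -> Option<Vec<Option<std::ops::Range<usize>>>> {
        let cnt = u16::from_be_bytes([*buf.first()?, *buf.get(1)?]) as usize;
        let mut offset = 2;
        let mut values = Vec::with_capacity(cnt);

        for _ in 0..cnt {
            let len = i32::from_be_bytes(buf.get(offset..offset + 4)?.try_into().ok()?);
            offset += 4;
            if len < 0 {
                values.push(None); // -1 length marks a NULL column
            } else {
                values.push(Some(offset..offset + len as usize));
                offset += len as usize;
            }
        }
        Some(values)
    }

    #[test]
    fn data_row_ranges_sketch() {
        // two columns: a NULL, then a 2-byte value "hi"
        let buf = b"\x00\x02\xff\xff\xff\xff\x00\x00\x00\x02hi";
        assert_eq!(data_row_ranges(buf), Some(vec![None, Some(10..12)]));
    }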
diff --git a/patches/sqlx-postgres/src/message/describe.rs b/patches/sqlx-postgres/src/message/describe.rs
new file mode 100644
index 000000000..d6ea7e89c
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/describe.rs
@@ -0,0 +1,103 @@
+use crate::io::{PgBufMutExt, PortalId, StatementId};
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+const DESCRIBE_PORTAL: u8 = b'P';
+const DESCRIBE_STATEMENT: u8 = b'S';
+
+/// Note: will emit both a RowDescription and a ParameterDescription message
+#[derive(Debug)]
+#[allow(dead_code)]
+pub enum Describe {
+    Statement(StatementId),
+    Portal(PortalId),
+}
+
+impl FrontendMessage for Describe {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe;
+
+    fn body_size_hint(&self) -> Saturating<usize> {
+        // Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT`
+        let mut size = Saturating(1);
+
+        match self {
+            Describe::Statement(id) => size += id.name_len(),
+            Describe::Portal(id) => size += id.name_len(),
+        }
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        match self {
+            // #[likely]
+            Describe::Statement(id) => {
+                buf.push(DESCRIBE_STATEMENT);
+                buf.put_statement_name(*id);
+            }
+
+            Describe::Portal(id) => {
+                buf.push(DESCRIBE_PORTAL);
+                buf.put_portal_name(*id);
+            }
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::message::FrontendMessage;
+
+    use super::{Describe, PortalId, StatementId};
+
+    #[test]
+    fn test_encode_describe_portal() {
+        const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0";
+
+        let mut buf = Vec::new();
+        let m = Describe::Portal(PortalId::TEST_VAL);
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[test]
+    fn test_encode_describe_unnamed_portal() {
+        const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
+
+        let mut buf = Vec::new();
+        let m = Describe::Portal(PortalId::UNNAMED);
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[test]
+    fn test_encode_describe_statement() {
+        const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0";
+
+        let mut buf = Vec::new();
+        let m = Describe::Statement(StatementId::TEST_VAL);
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[test]
+    fn test_encode_describe_unnamed_statement() {
+        const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
+
+        let mut buf = Vec::new();
+        let m = Describe::Statement(StatementId::UNNAMED);
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/execute.rs b/patches/sqlx-postgres/src/message/execute.rs
new file mode 100644
index 000000000..f82b7884b
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/execute.rs
@@ -0,0 +1,73 @@
+use std::num::Saturating;
+
+use sqlx_core::Error;
+
+use crate::io::{PgBufMutExt, PortalId};
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+
+pub struct Execute {
+    /// The id of the portal to execute.
+    pub portal: PortalId,
+
+    /// Maximum number of rows to return, if the portal contains a query
+    /// that returns rows (ignored otherwise). Zero denotes "no limit".
+    pub limit: u32,
+}
+
+impl FrontendMessage for Execute {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute;
+
+    fn body_size_hint(&self) -> Saturating<usize> {
+        let mut size = Saturating(0);
+
+        size += self.portal.name_len();
+        size += 4; // `limit` (Int32)
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        buf.put_portal_name(self.portal);
+        buf.extend(&self.limit.to_be_bytes());
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::io::PortalId;
+    use crate::message::FrontendMessage;
+
+    use super::Execute;
+
+    #[test]
+    fn test_encode_execute_named_portal() {
+        const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02";
+
+        let mut buf = Vec::new();
+        let m = Execute {
+            portal: PortalId::TEST_VAL,
+            limit: 2,
+        };
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[test]
+    fn test_encode_execute_unnamed_portal() {
+        const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2";
+
+        let mut buf = Vec::new();
+        let m = Execute {
+            portal: PortalId::UNNAMED,
+            limit: 1234567890,
+        };
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/flush.rs b/patches/sqlx-postgres/src/message/flush.rs
new file mode 100644
index 000000000..d1dfabbfa
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/flush.rs
@@ -0,0 +1,25 @@
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+/// The Flush message does not cause any specific output to be generated,
+/// but forces the backend to deliver any data pending in its output buffers.
+///
+/// A Flush must be sent after any extended-query command except Sync, if the
+/// frontend wishes to examine the results of that command before issuing more commands.
+#[derive(Debug)]
+pub struct Flush;
+
+impl FrontendMessage for Flush {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(0)
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
+        Ok(())
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/mod.rs b/patches/sqlx-postgres/src/message/mod.rs
new file mode 100644
index 000000000..e62f9bebb
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/mod.rs
@@ -0,0 +1,229 @@
+use sqlx_core::bytes::Bytes;
+use std::num::Saturating;
+
+use crate::error::Error;
+use crate::io::PgBufMutExt;
+
+mod authentication;
+mod backend_key_data;
+mod bind;
+mod close;
+mod command_complete;
+mod copy;
+mod data_row;
+mod describe;
+mod execute;
+mod flush;
+mod notification;
+mod parameter_description;
+mod parameter_status;
+mod parse;
+mod parse_complete;
+mod password;
+mod query;
+mod ready_for_query;
+mod response;
+mod row_description;
+mod sasl;
+mod ssl_request;
+mod startup;
+mod sync;
+mod terminate;
+
+pub use authentication::{Authentication, AuthenticationSasl};
+pub use backend_key_data::BackendKeyData;
+pub use bind::Bind;
+pub use close::Close;
+pub use command_complete::CommandComplete;
+pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData};
+pub use data_row::DataRow;
+pub use describe::Describe;
+pub use execute::Execute;
+#[allow(unused_imports)]
+pub use flush::Flush;
+pub use notification::Notification;
+pub use parameter_description::ParameterDescription;
+pub use parameter_status::ParameterStatus;
+pub use parse::Parse;
+pub use parse_complete::ParseComplete;
+pub use password::Password;
+pub use query::Query;
+pub use ready_for_query::{ReadyForQuery, TransactionStatus};
+pub use response::{Notice, PgSeverity};
+pub use row_description::RowDescription;
+pub use sasl::{SaslInitialResponse, SaslResponse};
+use sqlx_core::io::ProtocolEncode;
+pub use ssl_request::SslRequest;
+pub use startup::Startup;
+pub use sync::Sync;
+pub use terminate::Terminate;
+
+// Note: we can't use the same enum for both frontend and backend message formats
+// because there are duplicated format codes between them.
+//
+// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`.
+//
+#[derive(Debug, PartialOrd, PartialEq)]
+#[repr(u8)]
+pub enum FrontendMessageFormat {
+    Bind = b'B',
+    Close = b'C',
+    CopyData = b'd',
+    CopyDone = b'c',
+    CopyFail = b'f',
+    Describe = b'D',
+    Execute = b'E',
+    Flush = b'H',
+    Parse = b'P',
+    /// This message format is polymorphic. It's used for:
+    ///
+    /// * Plain password responses
+    /// * MD5 password responses
+    /// * SASL responses
+    /// * GSSAPI/SSPI responses
+    PasswordPolymorphic = b'p',
+    Query = b'Q',
+    Sync = b'S',
+    Terminate = b'X',
+}
+
+#[derive(Debug, PartialOrd, PartialEq)]
+#[repr(u8)]
+pub enum BackendMessageFormat {
+    Authentication,
+    BackendKeyData,
+    BindComplete,
+    CloseComplete,
+    CommandComplete,
+    CopyData,
+    CopyDone,
+    CopyInResponse,
+    CopyOutResponse,
+    DataRow,
+    EmptyQueryResponse,
+    ErrorResponse,
+    NoData,
+    NoticeResponse,
+    NotificationResponse,
+    ParameterDescription,
+    ParameterStatus,
+    ParseComplete,
+    PortalSuspended,
+    ReadyForQuery,
+    RowDescription,
+}
+
+#[derive(Debug)]
+pub struct ReceivedMessage {
+    pub format: BackendMessageFormat,
+    pub contents: Bytes,
+}
+
+impl ReceivedMessage {
+    #[inline]
+    pub fn decode<T>(self) -> Result<T, Error>
+    where
+        T: BackendMessage,
+    {
+        if T::FORMAT != self.format {
+            return Err(err_protocol!(
+                "Postgres protocol error: expected {:?}, got {:?}",
+                T::FORMAT,
+                self.format
+            ));
+        }
+
+        T::decode_body(self.contents).map_err(|e| match e {
+            Error::Protocol(s) => {
+                err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format)
+            }
+            other => other,
+        })
+    }
+}
+
+impl BackendMessageFormat {
+    pub fn try_from_u8(v: u8) -> Result<Self, Error> {
+        // https://www.postgresql.org/docs/current/protocol-message-formats.html
+
+        Ok(match v {
+            b'1' => BackendMessageFormat::ParseComplete,
+            b'2' => BackendMessageFormat::BindComplete,
+            b'3' => BackendMessageFormat::CloseComplete,
+            b'C' => BackendMessageFormat::CommandComplete,
+            b'd' => BackendMessageFormat::CopyData,
+            b'c' => BackendMessageFormat::CopyDone,
+            b'G' => BackendMessageFormat::CopyInResponse,
+            b'H' => BackendMessageFormat::CopyOutResponse,
+            b'D' => BackendMessageFormat::DataRow,
+            b'E' => BackendMessageFormat::ErrorResponse,
+            b'I' => BackendMessageFormat::EmptyQueryResponse,
+            b'A' => BackendMessageFormat::NotificationResponse,
+            b'K' => BackendMessageFormat::BackendKeyData,
+            b'N' => BackendMessageFormat::NoticeResponse,
+            b'R' => BackendMessageFormat::Authentication,
+            b'S' => BackendMessageFormat::ParameterStatus,
+            b'T' => BackendMessageFormat::RowDescription,
+            b'Z' => BackendMessageFormat::ReadyForQuery,
+            b'n' => BackendMessageFormat::NoData,
+            b's' => BackendMessageFormat::PortalSuspended,
+            b't' => BackendMessageFormat::ParameterDescription,
+
+            _ => return Err(err_protocol!("unknown message type: {:?}", v as char)),
+        })
+    }
+}
+
+pub(crate) trait FrontendMessage: Sized {
+    /// The format prefix of this message.
+    const FORMAT: FrontendMessageFormat;
+
+    /// Return the amount of space, in bytes, to reserve in the buffer passed to
+    /// [`Self::encode_body()`].
+    fn body_size_hint(&self) -> Saturating<usize>;
+
+    /// Encode this type as a frontend message in the Postgres protocol.
+    ///
+    /// The implementation should *not* include `Self::FORMAT` or the length prefix.
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error>;
+
+    #[inline(always)]
+    #[cfg_attr(not(test), allow(dead_code))]
+    fn encode_msg(self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        EncodeMessage(self).encode(buf)
+    }
+}
+
+pub(crate) trait BackendMessage: Sized {
+    /// The expected message format.
+    const FORMAT: BackendMessageFormat;
+
+    /// Decode this type from a backend message in the Postgres protocol.
+    ///
+    /// The format code and length prefix have already been read and are not included in `buf`.
+    fn decode_body(buf: Bytes) -> Result<Self, Error>;
+}
+
+pub struct EncodeMessage<F>(pub F);
+
+impl<F: FrontendMessage> ProtocolEncode<'_, ()> for EncodeMessage<F> {
+    fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), Error> {
+        let mut size_hint = self.0.body_size_hint();
+        // plus format code and length prefix
+        size_hint += 5;
+
+        // don't panic if `size_hint` is ridiculous
+        buf.try_reserve(size_hint.0).map_err(|e| {
+            err_protocol!(
+                "Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}",
+                size_hint.0,
+                F::FORMAT,
+            )
+        })?;
+
+        buf.push(F::FORMAT as u8);
+
+        buf.put_length_prefixed(|buf| self.0.encode_body(buf))
+    }
+}
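Every frontend message is framed identically: one format byte, then an Int32 length that counts itself plus the body, which is why `encode_with` reserves `body_size_hint() + 5` bytes. A minimal std-only sketch of that framing, checked against the `Query` test vector from this patch (the `frame` helper is illustrative, not the patch's `put_length_prefixed`):

    fn frame(format: u8, body: &[u8]) -> Vec<u8> {
        let mut msg = Vec::with_capacity(1 + 4 + body.len());
        msg.push(format);
        // the length prefix counts itself (4 bytes) plus the body,
        // but not the format byte
        msg.extend(((4 + body.len()) as i32).to_be_bytes());
        msg.extend_from_slice(body);
        msg
    }

    #[test]
    fn frame_matches_query_test_vector() {
        // mirrors test_encode_query below: 'Q' + Int32(13) + "SELECT 1\0"
        assert_eq!(frame(b'Q', b"SELECT 1\0"), b"Q\0\0\0\x0DSELECT 1\0");
    }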
diff --git a/patches/sqlx-postgres/src/message/notification.rs b/patches/sqlx-postgres/src/message/notification.rs
new file mode 100644
index 000000000..7bf029839
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/notification.rs
@@ -0,0 +1,39 @@
+use sqlx_core::bytes::{Buf, Bytes};
+
+use crate::error::Error;
+use crate::io::BufExt;
+use crate::message::{BackendMessage, BackendMessageFormat};
+
+#[derive(Debug)]
+pub struct Notification {
+    pub(crate) process_id: u32,
+    pub(crate) channel: Bytes,
+    pub(crate) payload: Bytes,
+}
+
+impl BackendMessage for Notification {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse;
+
+    fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
+        let process_id = buf.get_u32();
+        let channel = buf.get_bytes_nul()?;
+        let payload = buf.get_bytes_nul()?;
+
+        Ok(Self {
+            process_id,
+            channel,
+            payload,
+        })
+    }
+}
+
+#[test]
+fn test_decode_notification_response() {
+    const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0";
+
+    let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
+
+    assert_eq!(message.process_id, 0x34201002);
+    assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]);
+    assert_eq!(&*message.payload, &b"THIS IS A TEST"[..]);
+}
diff --git a/patches/sqlx-postgres/src/message/parameter_description.rs b/patches/sqlx-postgres/src/message/parameter_description.rs
new file mode 100644
index 000000000..8aa361a8e
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/parameter_description.rs
@@ -0,0 +1,56 @@
+use smallvec::SmallVec;
+use sqlx_core::bytes::{Buf, Bytes};
+
+use crate::error::Error;
+use crate::message::{BackendMessage, BackendMessageFormat};
+use crate::types::Oid;
+
+#[derive(Debug)]
+pub struct ParameterDescription {
+    pub types: SmallVec<[Oid; 6]>,
+}
+
+impl BackendMessage for ParameterDescription {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription;
+
+    fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
+        let cnt = buf.get_u16();
+        let mut types = SmallVec::with_capacity(cnt as usize);
+
+        for _ in 0..cnt {
+            types.push(Oid(buf.get_u32()));
+        }
+
+        Ok(Self { types })
+    }
+}
+
+#[test]
+fn test_decode_parameter_description() {
+    const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
+
+    let m = ParameterDescription::decode_body(DATA.into()).unwrap();
+
+    assert_eq!(m.types.len(), 2);
+    assert_eq!(m.types[0], Oid(0x0000_0000));
+    assert_eq!(m.types[1], Oid(0x0000_0500));
+}
+
+#[test]
+fn test_decode_empty_parameter_description() {
+    const DATA: &[u8] = b"\x00\x00";
+
+    let m = ParameterDescription::decode_body(DATA.into()).unwrap();
+
+    assert!(m.types.is_empty());
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_decode_parameter_description(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
+
+    b.iter(|| {
+        ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
+    });
+}
diff --git a/patches/sqlx-postgres/src/message/parameter_status.rs b/patches/sqlx-postgres/src/message/parameter_status.rs
new file mode 100644
index 000000000..d979d1895
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/parameter_status.rs
@@ -0,0 +1,65 @@
+use sqlx_core::bytes::Bytes;
+
+use crate::error::Error;
+use crate::io::BufExt;
+use crate::message::{BackendMessage, BackendMessageFormat};
+
+#[derive(Debug)]
+pub struct ParameterStatus {
+    pub name: String,
+    pub value: String,
+}
+
+impl BackendMessage for ParameterStatus {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus;
+
+    fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
+        let name = buf.get_str_nul()?;
+        let value = buf.get_str_nul()?;
+
+        Ok(Self { name, value })
+    }
+}
+
+#[test]
+fn test_decode_parameter_status() {
+    const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
+
+    let m = ParameterStatus::decode_body(DATA.into()).unwrap();
+
+    assert_eq!(&m.name, "client_encoding");
+    assert_eq!(&m.value, "UTF8");
+}
+
+#[test]
+fn test_decode_empty_parameter_status() {
+    const DATA: &[u8] = b"\x00\x00";
+
+    let m = ParameterStatus::decode_body(DATA.into()).unwrap();
+
+    assert!(m.name.is_empty());
+    assert!(m.value.is_empty());
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_decode_parameter_status(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
+
+    b.iter(|| {
+        ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
+    });
+}
+
+#[test]
+fn test_decode_parameter_status_response() {
+    const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0";
+
+    let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
+
+    assert_eq!(message.name, "crdb_version");
+    assert_eq!(
+        message.value,
+        "CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)"
+    );
+}
diff --git a/patches/sqlx-postgres/src/message/parse.rs b/patches/sqlx-postgres/src/message/parse.rs
new file mode 100644
index 000000000..3e77c3024
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/parse.rs
@@ -0,0 +1,77 @@
+use crate::io::BufMutExt;
+use crate::io::{PgBufMutExt, StatementId};
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use crate::types::Oid;
+use sqlx_core::Error;
+use std::num::Saturating;
+
+#[derive(Debug)]
+pub struct Parse<'a> {
+    /// The ID of the destination prepared statement.
+    pub statement: StatementId,
+
+    /// The query string to be parsed.
+    pub query: &'a str,
+
+    /// The parameter data types specified (could be zero). Note that this is not an
+    /// indication of the number of parameters that might appear in the query string,
+    /// only the number that the frontend wants to pre-specify types for.
+    pub param_types: &'a [Oid],
+}
+
+impl FrontendMessage for Parse<'_> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse;
+
+    fn body_size_hint(&self) -> Saturating<usize> {
+        let mut size = Saturating(0);
+
+        size += self.statement.name_len();
+
+        size += self.query.len();
+        size += 1; // NUL terminator
+
+        size += 2; // param_types_len
+
+        // `param_types`
+        size += self.param_types.len().saturating_mul(4);
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        buf.put_statement_name(self.statement);
+
+        buf.put_str_nul(self.query);
+
+        let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| {
+            err_protocol!(
+                "param_types.len() too large for binary protocol: {}",
+                self.param_types.len()
+            )
+        })?;
+
+        buf.extend(param_types_len.to_be_bytes());
+
+        for &oid in self.param_types {
+            buf.extend(oid.0.to_be_bytes());
+        }
+
+        Ok(())
+    }
+}
+
+#[test]
+fn test_encode_parse() {
+    const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19";
+
+    let mut buf = Vec::new();
+    let m = Parse {
+        statement: StatementId::TEST_VAL,
+        query: "SELECT $1",
+        param_types: &[Oid(25)],
+    };
+
+    m.encode_msg(&mut buf).unwrap();
+
+    assert_eq!(buf, EXPECTED);
+}
diff --git a/patches/sqlx-postgres/src/message/parse_complete.rs b/patches/sqlx-postgres/src/message/parse_complete.rs
new file mode 100644
index 000000000..3051f5ff9
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/parse_complete.rs
@@ -0,0 +1,13 @@
+use crate::message::{BackendMessage, BackendMessageFormat};
+use sqlx_core::bytes::Bytes;
+use sqlx_core::Error;
+
+pub struct ParseComplete;
+
+impl BackendMessage for ParseComplete {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete;
+
+    fn decode_body(_bytes: Bytes) -> Result<Self, Error> {
+        Ok(ParseComplete)
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/password.rs b/patches/sqlx-postgres/src/message/password.rs
new file mode 100644
index 000000000..4eaaeb15a
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/password.rs
@@ -0,0 +1,153 @@
+use crate::io::BufMutExt;
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use md5::{Digest, Md5};
+use sqlx_core::Error;
+use std::fmt::Write;
+use std::num::Saturating;
+
+#[derive(Debug)]
+pub enum Password<'a> {
+    Cleartext(&'a str),
+
+    Md5 {
+        password: &'a str,
+        username: &'a str,
+        salt: [u8; 4],
+    },
+}
+
+impl FrontendMessage for Password<'_> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        let mut size = Saturating(0);
+
+        match self {
+            Password::Cleartext(password) => {
+                // To avoid reporting the exact password length anywhere,
+                // we deliberately give a bad estimate.
+                //
+                // This shouldn't affect performance in the long run.
+                size += password
+                    .len()
+                    .saturating_add(1) // NUL terminator
+                    .checked_next_power_of_two()
+                    .unwrap_or(usize::MAX);
+            }
+            Password::Md5 { .. } => {
+                // "md5<32 hex chars>\0"
+                size += 36;
+            }
+        }
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        match self {
+            Password::Cleartext(password) => {
+                buf.put_str_nul(password);
+            }
+
+            Password::Md5 {
+                username,
+                password,
+                salt,
+            } => {
+                // The actual `PasswordMessage` can be computed in SQL as
+                // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
+
+                // Keep in mind the md5() function returns its result as a hex string.
+
+                let mut hasher = Md5::new();
+
+                hasher.update(password);
+                hasher.update(username);
+
+                let mut output = String::with_capacity(35);
+
+                let _ = write!(output, "{:x}", hasher.finalize_reset());
+
+                hasher.update(&output);
+                hasher.update(salt);
+
+                output.clear();
+
+                let _ = write!(output, "md5{:x}", hasher.finalize());
+
+                buf.put_str_nul(&output);
+            }
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::message::FrontendMessage;
+
+    use super::Password;
+
+    #[test]
+    fn test_encode_clear_password() {
+        const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
+
+        let mut buf = Vec::new();
+        let m = Password::Cleartext("password");
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[test]
+    fn test_encode_md5_password() {
+        const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
+
+        let mut buf = Vec::new();
+        let m = Password::Md5 {
+            password: "password",
+            username: "root",
+            salt: [147, 24, 57, 152],
+        };
+
+        m.encode_msg(&mut buf).unwrap();
+
+        assert_eq!(buf, EXPECTED);
+    }
+
+    #[cfg(all(test, not(debug_assertions)))]
+    #[bench]
+    fn bench_encode_clear_password(b: &mut test::Bencher) {
+        use test::black_box;
+
+        let mut buf = Vec::with_capacity(128);
+
+        b.iter(|| {
+            buf.clear();
+
+            let _ = black_box(Password::Cleartext("password")).encode_msg(&mut buf);
+        });
+    }
+
+    #[cfg(all(test, not(debug_assertions)))]
+    #[bench]
+    fn bench_encode_md5_password(b: &mut test::Bencher) {
+        use test::black_box;
+
+        let mut buf = Vec::with_capacity(128);
+
+        b.iter(|| {
+            buf.clear();
+
+            let _ = black_box(Password::Md5 {
+                password: "password",
+                username: "root",
+                salt: [147, 24, 57, 152],
+            })
+            .encode_msg(&mut buf);
+        });
+    }
+}
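The salted-MD5 response above is `"md5" + hex(md5(hex(md5(password + username)) + salt))`. A standalone sketch using the same `md-5` crate API the patch imports, reproducing the expected digest from `test_encode_md5_password` (the `md5_password` helper name is mine):

    use md5::{Digest, Md5};

    fn md5_password(password: &str, username: &str, salt: [u8; 4]) -> String {
        // inner = hex(md5(password + username))
        let mut hasher = Md5::new();
        hasher.update(password);
        hasher.update(username);
        let inner = format!("{:x}", hasher.finalize());

        // outer = "md5" + hex(md5(inner + salt))
        let mut hasher = Md5::new();
        hasher.update(inner.as_bytes());
        hasher.update(salt);
        format!("md5{:x}", hasher.finalize())
    }

    #[test]
    fn matches_patch_test_vector() {
        assert_eq!(
            md5_password("password", "root", [147, 24, 57, 152]),
            "md53e2c9d99d49b201ef867a36f3f9ed62c"
        );
    }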
diff --git a/patches/sqlx-postgres/src/message/query.rs b/patches/sqlx-postgres/src/message/query.rs
new file mode 100644
index 000000000..788d7808f
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/query.rs
@@ -0,0 +1,37 @@
+use crate::io::BufMutExt;
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+#[derive(Debug)]
+pub struct Query<'a>(pub &'a str);
+
+impl FrontendMessage for Query<'_> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query;
+
+    fn body_size_hint(&self) -> Saturating<usize> {
+        let mut size = Saturating(0);
+
+        size += self.0.len();
+        size += 1; // NUL terminator
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        buf.put_str_nul(self.0);
+        Ok(())
+    }
+}
+
+#[test]
+fn test_encode_query() {
+    const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0";
+
+    let mut buf = Vec::new();
+    let m = Query("SELECT 1");
+
+    m.encode_msg(&mut buf).unwrap();
+
+    assert_eq!(buf, EXPECTED);
+}
diff --git a/patches/sqlx-postgres/src/message/ready_for_query.rs b/patches/sqlx-postgres/src/message/ready_for_query.rs
new file mode 100644
index 000000000..a1f6761b8
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/ready_for_query.rs
@@ -0,0 +1,56 @@
+use sqlx_core::bytes::Bytes;
+
+use crate::error::Error;
+use crate::message::{BackendMessage, BackendMessageFormat};
+
+#[derive(Debug)]
+#[repr(u8)]
+pub enum TransactionStatus {
+    /// Not in a transaction block.
+    Idle = b'I',
+
+    /// In a transaction block.
+    Transaction = b'T',
+
+    /// In a _failed_ transaction block. Queries will be rejected until the block is ended.
+    Error = b'E',
+}
+
+#[derive(Debug)]
+pub struct ReadyForQuery {
+    pub transaction_status: TransactionStatus,
+}
+
+impl BackendMessage for ReadyForQuery {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery;
+
+    fn decode_body(buf: Bytes) -> Result<Self, Error> {
+        let status = match buf[0] {
+            b'I' => TransactionStatus::Idle,
+            b'T' => TransactionStatus::Transaction,
+            b'E' => TransactionStatus::Error,
+
+            status => {
+                return Err(err_protocol!(
+                    "unknown transaction status: {:?}",
+                    status as char
+                ));
+            }
+        };
+
+        Ok(Self {
+            transaction_status: status,
+        })
+    }
+}
+
+#[test]
+fn test_decode_ready_for_query() -> Result<(), Error> {
+    const DATA: &[u8] = b"E";
+
+    let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?;
+
+    assert!(matches!(m.transaction_status, TransactionStatus::Error));
+
+    Ok(())
+}
diff --git a/patches/sqlx-postgres/src/message/response.rs b/patches/sqlx-postgres/src/message/response.rs
new file mode 100644
index 000000000..d6e43e087
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/response.rs
@@ -0,0 +1,272 @@
+use std::ops::Range;
+use std::str::from_utf8;
+
+use memchr::memchr;
+
+use sqlx_core::bytes::Bytes;
+
+use crate::error::Error;
+use crate::io::ProtocolDecode;
+use crate::message::{BackendMessage, BackendMessageFormat};
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[repr(u8)]
+pub enum PgSeverity {
+    Panic,
+    Fatal,
+    Error,
+    Warning,
+    Notice,
+    Debug,
+    Info,
+    Log,
+}
+
+impl PgSeverity {
+    #[inline]
+    pub fn is_error(self) -> bool {
+        matches!(self, Self::Panic | Self::Fatal | Self::Error)
+    }
+}
+
+impl TryFrom<&str> for PgSeverity {
+    type Error = Error;
+
+    fn try_from(s: &str) -> Result<Self, Error> {
+        let result = match s {
+            "PANIC" => PgSeverity::Panic,
+            "FATAL" => PgSeverity::Fatal,
+            "ERROR" => PgSeverity::Error,
+            "WARNING" => PgSeverity::Warning,
+            "NOTICE" => PgSeverity::Notice,
+            "DEBUG" => PgSeverity::Debug,
+            "INFO" => PgSeverity::Info,
+            "LOG" => PgSeverity::Log,
+
+            severity => {
+                return Err(err_protocol!("unknown severity: {:?}", severity));
+            }
+        };
+
+        Ok(result)
+    }
+}
+
+#[derive(Debug)]
+pub struct Notice {
+    storage: Bytes,
+    severity: PgSeverity,
+    message: Range<usize>,
+    code: Range<usize>,
+}
+
+impl Notice {
+    #[inline]
+    pub fn severity(&self) -> PgSeverity {
+        self.severity
+    }
+
+    #[inline]
+    pub fn code(&self) -> &str {
+        self.get_cached_str(self.code.clone())
+    }
+
+    #[inline]
+    pub fn message(&self) -> &str {
+        self.get_cached_str(self.message.clone())
+    }
+
+    // Field descriptions available here:
+    // https://www.postgresql.org/docs/current/protocol-error-fields.html
+
+    #[inline]
+    pub fn get(&self, ty: u8) -> Option<&str> {
+        self.get_raw(ty).and_then(|v| from_utf8(v).ok())
+    }
+
+    pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
+        self.fields()
+            .filter(|(field, _)| *field == ty)
+            .map(|(_, range)| &self.storage[range])
+            .next()
+    }
+}
+
+impl Notice {
+    #[inline]
+    fn fields(&self) -> Fields<'_> {
+        Fields {
+            storage: &self.storage,
+            offset: 0,
+        }
+    }
+
+    #[inline]
+    fn get_cached_str(&self, cache: Range<usize>) -> &str {
+        // unwrap: this cannot fail at this stage
+        from_utf8(&self.storage[cache]).unwrap()
+    }
+}
+
+impl ProtocolDecode<'_> for Notice {
+    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
+        // In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
+        // Newer versions additionally come with the V field that is guaranteed to be in English.
+        // We thus read both versions and prefer the unlocalized one if available.
+        const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
+        let mut severity_v = None;
+        let mut severity_s = None;
+        let mut message = 0..0;
+        let mut code = 0..0;
+
+        // we cache the three always present fields
+        // this enables to keep the access time down for the fields most likely accessed
+
+        let fields = Fields {
+            storage: &buf,
+            offset: 0,
+        };
+
+        for (field, v) in fields {
+            if !(message.is_empty() || code.is_empty()) {
+                // stop iterating when we have the 3 fields we were looking for
+                // we assume V (severity) was the first field as it should be
+                break;
+            }
+
+            match field {
+                b'S' => {
+                    severity_s = from_utf8(&buf[v.clone()])
+                        // If the error string is not UTF-8, we have no hope of interpreting it,
+                        // localized or not. The `V` field would likely fail to parse as well.
+                        .map_err(|_| notice_protocol_err())?
+                        .try_into()
+                        // If we couldn't parse the severity here, it might just be localized.
+                        .ok();
+                }
+
+                b'V' => {
+                    // Propagate errors here, because V is not localized and
+                    // thus we are missing a possible variant.
+                    severity_v = Some(
+                        from_utf8(&buf[v.clone()])
+                            .map_err(|_| notice_protocol_err())?
+                            .try_into()?,
+                    );
+                }
+
+                b'M' => {
+                    _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
+                    message = v;
+                }
+
+                b'C' => {
+                    _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
+                    code = v;
+                }
+
+                // If more fields are added, make sure to check that they are valid UTF-8,
+                // otherwise the get_cached_str method will panic.
+                _ => {}
+            }
+        }
+
+        Ok(Self {
+            severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY),
+            message,
+            code,
+            storage: buf,
+        })
+    }
+}
+
+impl BackendMessage for Notice {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
+
+    fn decode_body(buf: Bytes) -> Result<Self, Error> {
+        // Keeping both impls for now
+        Self::decode_with(buf, ())
+    }
+}
+
+/// An iterator over each field in the Error (or Notice) response.
+struct Fields<'a> {
+    storage: &'a [u8],
+    offset: usize,
+}
+
+impl<'a> Iterator for Fields<'a> {
+    type Item = (u8, Range<usize>);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // The fields in the response body are sequentially stored as [tag][string],
+        // ending in a final, additional [nul]
+
+        let ty = *self.storage.get(self.offset)?;
+
+        if ty == 0 {
+            return None;
+        }
+
+        // Consume the type byte
+        self.offset = self.offset.checked_add(1)?;
+
+        let start = self.offset;
+
+        let len = memchr(b'\0', self.storage.get(start..)?)?;
+
+        // Neither can overflow as they will always be `<= self.storage.len()`.
+        let end = self.offset + len;
+        self.offset = end + 1;
+
+        Some((ty, start..end))
+    }
+}
+
+fn notice_protocol_err() -> Error {
+    // https://github.com/launchbadge/sqlx/issues/1144
+    Error::Protocol(
+        "Postgres returned a non-UTF-8 string for its error message. \
+         This is most likely due to an error that occurred during authentication and \
+         the default lc_messages locale is not binary-compatible with UTF-8. \
+         See the server logs for the error details."
+            .into(),
+    )
+}
+
+#[test]
+fn test_decode_error_response() {
+    const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
+
+    let m = Notice::decode(Bytes::from_static(DATA)).unwrap();
+
+    assert_eq!(
+        m.message(),
+        "extension \"uuid-ossp\" already exists, skipping"
+    );
+
+    assert!(matches!(m.severity(), PgSeverity::Notice));
+    assert_eq!(m.code(), "42710");
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_error_response_get_message(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
+
+    let res = Notice::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
+
+    b.iter(|| {
+        let _ = test::black_box(&res).message();
+    });
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_decode_error_response(b: &mut test::Bencher) {
+    const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
+
+    b.iter(|| {
+        let _ = Notice::decode(test::black_box(Bytes::from_static(DATA)));
+    });
+}
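An ErrorResponse or NoticeResponse body is a run of `(tag byte, NUL-terminated string)` pairs terminated by a lone NUL, which is exactly what `Fields` walks. A std-only sketch that collects the fields as slices instead of ranges (hypothetical helper; the patch keeps ranges so `Notice` can own its `Bytes` storage without self-references):

    fn notice_fields(body: &[u8]) -> Vec<(u8, &str)> {
        let mut fields = Vec::new();
        let mut rest = body;

        // each field is a tag byte plus a NUL-terminated string;
        // a zero tag terminates the list
        while let [tag, tail @ ..] = rest {
            if *tag == 0 {
                break;
            }
            let end = tail.iter().position(|&b| b == 0).unwrap_or(tail.len());
            if let Ok(s) = std::str::from_utf8(&tail[..end]) {
                fields.push((*tag, s));
            }
            rest = &tail[end.saturating_add(1).min(tail.len())..];
        }
        fields
    }

    #[test]
    fn notice_fields_sketch() {
        let fields = notice_fields(b"C42710\0Mduplicate extension\0\0");
        assert_eq!(fields, vec![(b'C', "42710"), (b'M', "duplicate extension")]);
    }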
diff --git a/patches/sqlx-postgres/src/message/row_description.rs b/patches/sqlx-postgres/src/message/row_description.rs
new file mode 100644
index 000000000..3f3155ed5
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/row_description.rs
@@ -0,0 +1,99 @@
+use sqlx_core::bytes::{Buf, Bytes};
+
+use crate::error::Error;
+use crate::io::BufExt;
+use crate::message::{BackendMessage, BackendMessageFormat};
+use crate::types::Oid;
+
+#[derive(Debug)]
+pub struct RowDescription {
+    pub fields: Vec<Field>,
+}
+
+#[derive(Debug)]
+pub struct Field {
+    /// The name of the field.
+    pub name: String,
+
+    /// If the field can be identified as a column of a specific table, the
+    /// object ID of the table; otherwise zero.
+    pub relation_id: Option<i32>,
+
+    /// If the field can be identified as a column of a specific table, the attribute number of
+    /// the column; otherwise zero.
+    pub relation_attribute_no: Option<i16>,
+
+    /// The object ID of the field's data type.
+    pub data_type_id: Oid,
+
+    /// The data type size (see pg_type.typlen). Note that negative values denote
+    /// variable-width types.
+    #[allow(dead_code)]
+    pub data_type_size: i16,
+
+    /// The type modifier (see pg_attribute.atttypmod). The meaning of the
+    /// modifier is type-specific.
+    #[allow(dead_code)]
+    pub type_modifier: i32,
+
+    /// The format code being used for the field.
+    #[allow(dead_code)]
+    pub format: i16,
+}
+
+impl BackendMessage for RowDescription {
+    const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription;
+
+    fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
+        if buf.len() < 2 {
+            return Err(err_protocol!(
+                "expected at least 2 bytes, got {}",
+                buf.len()
+            ));
+        }
+
+        let cnt = buf.get_u16();
+        let mut fields = Vec::with_capacity(cnt as usize);
+
+        for _ in 0..cnt {
+            let name = buf.get_str_nul()?.to_owned();
+
+            if buf.len() < 18 {
+                return Err(err_protocol!(
+                    "expected at least 18 bytes after field name {name:?}, got {}",
+                    buf.len()
+                ));
+            }
+
+            let relation_id = buf.get_i32();
+            let relation_attribute_no = buf.get_i16();
+            let data_type_id = Oid(buf.get_u32());
+            let data_type_size = buf.get_i16();
+            let type_modifier = buf.get_i32();
+            let format = buf.get_i16();
+
+            fields.push(Field {
+                name,
+                relation_id: if relation_id == 0 {
+                    None
+                } else {
+                    Some(relation_id)
+                },
+                relation_attribute_no: if relation_attribute_no == 0 {
+                    None
+                } else {
+                    Some(relation_attribute_no)
+                },
+                data_type_id,
+                data_type_size,
+                type_modifier,
+                format,
+            })
+        }
+
+        Ok(Self { fields })
+    }
+}
+
+// TODO: Unit Test RowDescription
+// TODO: Benchmark RowDescription
diff --git a/patches/sqlx-postgres/src/message/sasl.rs b/patches/sqlx-postgres/src/message/sasl.rs
new file mode 100644
index 000000000..9d393189b
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/sasl.rs
@@ -0,0 +1,69 @@
+use crate::io::BufMutExt;
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+pub struct SaslInitialResponse<'a> {
+    pub response: &'a str,
+    pub plus: bool,
+}
+
+impl SaslInitialResponse<'_> {
+    #[inline(always)]
+    fn selected_mechanism(&self) -> &'static str {
+        if self.plus {
+            "SCRAM-SHA-256-PLUS"
+        } else {
+            "SCRAM-SHA-256"
+        }
+    }
+}
+
+impl FrontendMessage for SaslInitialResponse<'_> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        let mut size = Saturating(0);
+
+        size += self.selected_mechanism().len();
+        size += 1; // NUL terminator
+
+        size += 4; // response_len
+        size += self.response.len();
+
+        size
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        // name of the SASL authentication mechanism that the client selected
+        buf.put_str_nul(self.selected_mechanism());
+
+        let response_len = i32::try_from(self.response.len()).map_err(|_| {
+            err_protocol!(
+                "SASL Initial Response length too long for protocol: {}",
+                self.response.len()
+            )
+        })?;
+
+        buf.extend_from_slice(&response_len.to_be_bytes());
+        buf.extend_from_slice(self.response.as_bytes());
+
+        Ok(())
+    }
+}
+
+pub struct SaslResponse<'a>(pub &'a str);
+
+impl FrontendMessage for SaslResponse<'_> {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
+
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(self.0.len())
+    }
+
+    fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
+        buf.extend(self.0.as_bytes());
+        Ok(())
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/ssl_request.rs b/patches/sqlx-postgres/src/message/ssl_request.rs
new file mode 100644
index 000000000..09c886221
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/ssl_request.rs
@@ -0,0 +1,38 @@
+use crate::io::ProtocolEncode;
+
+pub struct SslRequest;
+
+impl SslRequest {
+    // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST
+    pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f";
+}
+
+// Cannot impl FrontendMessage because it does not have a format code
+impl ProtocolEncode<'_> for SslRequest {
+    #[inline(always)]
+    fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
+        buf.extend_from_slice(Self::BYTES);
+        Ok(())
+    }
+}
+
+#[test]
+fn test_encode_ssl_request() {
+    let mut buf = Vec::new();
+
+    // Int32(8)
+    // Length of message contents in bytes, including self.
+    buf.extend_from_slice(&8_u32.to_be_bytes());
+
+    // Int32(80877103)
+    // The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
+    // and 5679 in the least significant 16 bits.
+    // (To avoid confusion, this code must not be the same as any protocol version number.)
+    buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes());
+
+    let mut encoded = Vec::new();
+    SslRequest.encode(&mut encoded).unwrap();
+
+    assert_eq!(buf, SslRequest::BYTES);
+    assert_eq!(buf, encoded);
+}
diff --git a/patches/sqlx-postgres/src/message/startup.rs b/patches/sqlx-postgres/src/message/startup.rs
new file mode 100644
index 000000000..1c6d735ab
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/startup.rs
@@ -0,0 +1,96 @@
+use crate::io::PgBufMutExt;
+use crate::io::{BufMutExt, ProtocolEncode};
+
+// To begin a session, a frontend opens a connection to the server and sends a startup message.
+// This message includes the names of the user and of the database the user wants to connect to;
+// it also identifies the particular protocol version to be used.
+
+// Optionally, the startup message can include additional settings for run-time parameters.
+
+pub struct Startup<'a> {
+    /// The database user name to connect as. Required; there is no default.
+    pub username: Option<&'a str>,
+
+    /// The database to connect to. Defaults to the user name.
+    pub database: Option<&'a str>,
+
+    /// Additional start-up params.
+    pub params: &'a [(&'a str, &'a str)],
+}
+
+// Startup cannot impl FrontendMessage because it doesn't have a format code.
+impl ProtocolEncode<'_> for Startup<'_> {
+    fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
+        buf.reserve(120);
+
+        buf.put_length_prefixed(|buf| {
+            // The protocol version number. The most significant 16 bits are the
+            // major version number (3 for the protocol described here). The least
+            // significant 16 bits are the minor version number (0
+            // for the protocol described here).
+            buf.extend(&196_608_i32.to_be_bytes());
+
+            if let Some(username) = self.username {
+                // The database user name to connect as.
+                encode_startup_param(buf, "user", username);
+            }
+
+            if let Some(database) = self.database {
+                // The database to connect to. Defaults to the user name.
+                encode_startup_param(buf, "database", database);
+            }
+
+            for (name, value) in self.params {
+                encode_startup_param(buf, name, value);
+            }
+
+            // A zero byte is required as a terminator
+            // after the last name/value pair.
+            buf.push(0);
+
+            Ok(())
+        })
+    }
+}
+
+#[inline]
+fn encode_startup_param(buf: &mut Vec<u8>, name: &str, value: &str) {
+    buf.put_str_nul(name);
+    buf.put_str_nul(value);
+}
+
+#[test]
+fn test_encode_startup() {
+    const EXPECTED: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0";
+
+    let mut buf = Vec::new();
+    let m = Startup {
+        username: Some("postgres"),
+        database: Some("postgres"),
+        params: &[],
+    };
+
+    m.encode(&mut buf).unwrap();
+
+    assert_eq!(buf, EXPECTED);
+}
+
+#[cfg(all(test, not(debug_assertions)))]
+#[bench]
+fn bench_encode_startup(b: &mut test::Bencher) {
+    use test::black_box;
+
+    let mut buf = Vec::with_capacity(128);
+
+    b.iter(|| {
+        buf.clear();
+
+        let _ = black_box(Startup {
+            username: Some("postgres"),
+            database: Some("postgres"),
+            params: &[],
+        })
+        .encode(&mut buf);
+    });
+}
diff --git a/patches/sqlx-postgres/src/message/sync.rs b/patches/sqlx-postgres/src/message/sync.rs
new file mode 100644
index 000000000..56f449874
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/sync.rs
@@ -0,0 +1,20 @@
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+#[derive(Debug)]
+pub struct Sync;
+
+impl FrontendMessage for Sync {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(0)
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
+        Ok(())
+    }
+}
diff --git a/patches/sqlx-postgres/src/message/terminate.rs b/patches/sqlx-postgres/src/message/terminate.rs
new file mode 100644
index 000000000..39f8ff6e6
--- /dev/null
+++ b/patches/sqlx-postgres/src/message/terminate.rs
@@ -0,0 +1,19 @@
+use crate::message::{FrontendMessage, FrontendMessageFormat};
+use sqlx_core::Error;
+use std::num::Saturating;
+
+pub struct Terminate;
+
+impl FrontendMessage for Terminate {
+    const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate;
+
+    #[inline(always)]
+    fn body_size_hint(&self) -> Saturating<usize> {
+        Saturating(0)
+    }
+
+    #[inline(always)]
+    fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
+        Ok(())
+    }
+}
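Unlike every other frontend message, the startup packet has no format byte: just an Int32 length (counting itself), the Int32 protocol version `196608` (`3 << 16`, i.e. protocol 3.0), and NUL-terminated name/value pairs closed by one extra NUL. A std-only sketch, checked against `test_encode_startup` above (the `startup_packet` helper is illustrative):

    fn startup_packet(params: &[(&str, &str)]) -> Vec<u8> {
        let mut body = (3i32 << 16).to_be_bytes().to_vec(); // 196_608
        for (name, value) in params {
            body.extend_from_slice(name.as_bytes());
            body.push(0);
            body.extend_from_slice(value.as_bytes());
            body.push(0);
        }
        body.push(0); // terminator after the last pair

        // the length prefix counts itself plus the body
        let mut msg = ((4 + body.len()) as i32).to_be_bytes().to_vec();
        msg.extend_from_slice(&body);
        msg
    }

    #[test]
    fn matches_startup_test_vector() {
        let expected: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0";
        assert_eq!(
            startup_packet(&[("user", "postgres"), ("database", "postgres")]),
            expected
        );
    }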
diff --git a/patches/sqlx-postgres/src/migrate.rs b/patches/sqlx-postgres/src/migrate.rs
new file mode 100644
index 000000000..da3080581
--- /dev/null
+++ b/patches/sqlx-postgres/src/migrate.rs
@@ -0,0 +1,314 @@
+use std::str::FromStr;
+use std::time::Duration;
+use std::time::Instant;
+
+use futures_core::future::BoxFuture;
+
+pub(crate) use sqlx_core::migrate::MigrateError;
+pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration};
+pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase};
+
+use crate::connection::{ConnectOptions, Connection};
+use crate::error::Error;
+use crate::executor::Executor;
+use crate::query::query;
+use crate::query_as::query_as;
+use crate::query_scalar::query_scalar;
+use crate::{PgConnectOptions, PgConnection, Postgres};
+
+fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> {
+    let mut options = PgConnectOptions::from_str(url)?;
+
+    // pull out the name of the database to create
+    let database = options
+        .database
+        .as_deref()
+        .unwrap_or(&options.username)
+        .to_owned();
+
+    // switch us to the maintenance database
+    // use `postgres` _unless_ the database is postgres, in which case, use `template1`
+    // this matches the behavior of the `createdb` util
+    options.database = if database == "postgres" {
+        Some("template1".into())
+    } else {
+        Some("postgres".into())
+    };
+
+    Ok((options, database))
+}
+
+impl MigrateDatabase for Postgres {
+    fn create_database(url: &str) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            let (options, database) = parse_for_maintenance(url)?;
+            let mut conn = options.connect().await?;
+
+            let _ = conn
+                .execute(&*format!(
+                    "CREATE DATABASE \"{}\"",
+                    database.replace('"', "\"\"")
+                ))
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn database_exists(url: &str) -> BoxFuture<'_, Result<bool, Error>> {
+        Box::pin(async move {
+            let (options, database) = parse_for_maintenance(url)?;
+            let mut conn = options.connect().await?;
+
+            let exists: bool =
+                query_scalar("select exists(SELECT 1 from pg_database WHERE datname = $1)")
+                    .bind(database)
+                    .fetch_one(&mut conn)
+                    .await?;
+
+            Ok(exists)
+        })
+    }
+
+    fn drop_database(url: &str) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            let (options, database) = parse_for_maintenance(url)?;
+            let mut conn = options.connect().await?;
+
+            let _ = conn
+                .execute(&*format!(
+                    "DROP DATABASE IF EXISTS \"{}\"",
+                    database.replace('"', "\"\"")
+                ))
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn force_drop_database(url: &str) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            let (options, database) = parse_for_maintenance(url)?;
+            let mut conn = options.connect().await?;
+
+            let row: (String,) = query_as("SELECT current_setting('server_version_num')")
+                .fetch_one(&mut conn)
+                .await?;
+
+            let version = row.0.parse::<i32>().unwrap();
+
+            let pid_type = if version >= 90200 { "pid" } else { "procpid" };
+
+            conn.execute(&*format!(
+                "SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \
+                 WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()"
+            ))
+            .await?;
+
+            Self::drop_database(url).await
+        })
+    }
+}
+
+impl Migrate for PgConnection {
+    fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
+        Box::pin(async move {
+            // language=SQL
+            self.execute(
+                r#"
+CREATE TABLE IF NOT EXISTS _sqlx_migrations (
+    version BIGINT PRIMARY KEY,
+    description TEXT NOT NULL,
+    installed_on TIMESTAMPTZ NOT NULL DEFAULT now(),
+    success BOOLEAN NOT NULL,
+    checksum BYTEA NOT NULL,
+    execution_time BIGINT NOT NULL
+);
+                "#,
+            )
+            .await?;
+
+            Ok(())
+        })
+    }
+
+    fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>> {
+        Box::pin(async move {
+            // language=SQL
+            let row: Option<(i64,)> = query_as(
+                "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1",
+            )
+            .fetch_optional(self)
+            .await?;
+
+            Ok(row.map(|r| r.0))
+        })
+    }
+
+    fn list_applied_migrations(
+        &mut self,
+    ) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>> {
+        Box::pin(async move {
+            // language=SQL
+            let rows: Vec<(i64, Vec<u8>)> =
+                query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version")
+                    .fetch_all(self)
+                    .await?;
+
+            let migrations = rows
+                .into_iter()
+                .map(|(version, checksum)| AppliedMigration {
+                    version,
+                    checksum: checksum.into(),
+                })
+                .collect();
+
+            Ok(migrations)
+        })
+    }
+
+    fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
+        Box::pin(async move {
+            let database_name = current_database(self).await?;
+            let lock_id = generate_lock_id(&database_name);
+
+            // create an application lock over the database
+            // this function will not return until the lock is acquired
+
+            // https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS
+            // https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS-TABLE
+
+            // language=SQL
+            let _ = query("SELECT pg_advisory_lock($1)")
+                .bind(lock_id)
+                .execute(self)
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
+        Box::pin(async move {
+            let database_name = current_database(self).await?;
+            let lock_id = generate_lock_id(&database_name);
+
+            // language=SQL
+            let _ = query("SELECT pg_advisory_unlock($1)")
+                .bind(lock_id)
+                .execute(self)
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn apply<'e: 'm, 'm>(
+        &'e mut self,
+        migration: &'m Migration,
+    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
+        Box::pin(async move {
+            let start = Instant::now();
+
+            // execute migration queries
+            if migration.no_tx {
+                execute_migration(self, migration).await?;
+            } else {
+                // Use a single transaction for the actual migration script and the essential bookkeeping so we never
+                // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966.
+                // The `execution_time` however can only be measured for the whole transaction. This value _only_ exists for
+                // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1
+                // and update it once the actual transaction has completed.
+                let mut tx = self.begin().await?;
+                execute_migration(&mut tx, migration).await?;
+                tx.commit().await?;
+            }
+
+            // Update `execution_time`.
+            // NOTE: The process may disconnect/die at this point, so the elapsed time value might be lost. We accept
+            // this small risk since this value is not super important.
+            let elapsed = start.elapsed();
+
+            // language=SQL
+            #[allow(clippy::cast_possible_truncation)]
+            let _ = query(
+                r#"
+    UPDATE _sqlx_migrations
+    SET execution_time = $1
+    WHERE version = $2
+                "#,
+            )
+            .bind(elapsed.as_nanos() as i64)
+            .bind(migration.version)
+            .execute(self)
+            .await?;
+
+            Ok(elapsed)
+        })
+    }
+
+    fn revert<'e: 'm, 'm>(
+        &'e mut self,
+        migration: &'m Migration,
+    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
+        Box::pin(async move {
+            // Use a single transaction for the actual migration script and the essential bookkeeping so we never
+            // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966.
+            let mut tx = self.begin().await?;
+            let start = Instant::now();
+
+            let _ = tx.execute(&*migration.sql).await?;
+
+            // language=SQL
+            let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#)
+                .bind(migration.version)
+                .execute(&mut *tx)
+                .await?;
+
+            tx.commit().await?;
+
+            let elapsed = start.elapsed();
+
+            Ok(elapsed)
+        })
+    }
+}
+
+async fn execute_migration(
+    conn: &mut PgConnection,
+    migration: &Migration,
+) -> Result<(), MigrateError> {
+    let _ = conn
+        .execute(&*migration.sql)
+        .await
+        .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;
+
+    // language=SQL
+    let _ = query(
+        r#"
+    INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time )
+    VALUES ( $1, $2, TRUE, $3, -1 )
+                "#,
+    )
+    .bind(migration.version)
+    .bind(&*migration.description)
+    .bind(&*migration.checksum)
+    .execute(conn)
+    .await?;
+
+    Ok(())
+}
+
+async fn current_database(conn: &mut PgConnection) -> Result<String, MigrateError> {
+    // language=SQL
+    Ok(query_scalar("SELECT current_database()")
+        .fetch_one(conn)
+        .await?)
+}
+
+// inspired from rails: https://github.com/rails/rails/blob/6e49cc77ab3d16c06e12f93158eaf3e507d4120e/activerecord/lib/active_record/migration.rb#L1308
+fn generate_lock_id(database_name: &str) -> i64 {
+    const CRC_IEEE: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
+    // 0x3d32ad9e chosen by fair dice roll
+    0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64)
+}
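The advisory-lock key is derived deterministically from the database name (CRC-32 with the ISO-HDLC polynomial, scaled by a constant), so every migrator pointed at the same database contends on the same `pg_advisory_lock` key. A sketch using the same `crc` crate API as the patch (the `lock_id_for` name is hypothetical):

    fn lock_id_for(database_name: &str) -> i64 {
        const CRC_IEEE: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
        // widen to i64 before multiplying: u32::MAX * 0x3d32ad9e still fits in 63 bits
        0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64)
    }

    #[test]
    fn lock_id_is_deterministic() {
        // two processes migrating the same database compute the same key
        // and therefore serialize on pg_advisory_lock(lock_id_for("mydb"))
        assert_eq!(lock_id_for("mydb"), lock_id_for("mydb"));
    }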
diff --git a/patches/sqlx-postgres/src/options/mod.rs b/patches/sqlx-postgres/src/options/mod.rs
new file mode 100644
index 000000000..a0b222606
--- /dev/null
+++ b/patches/sqlx-postgres/src/options/mod.rs
@@ -0,0 +1,688 @@
+use std::borrow::Cow;
+use std::env::var;
+use std::fmt::{Display, Write};
+use std::path::{Path, PathBuf};
+
+pub use ssl_mode::PgSslMode;
+
+use crate::{connection::LogSettings, net::tls::CertificateInput};
+
+mod connect;
+mod parse;
+mod pgpass;
+mod ssl_mode;
+
+/// Options and flags which can be used to configure a PostgreSQL connection.
+///
+/// A value of `PgConnectOptions` can be parsed from a connection URL,
+/// as described by [libpq](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING).
+///
+/// The general form for a connection URL is:
+///
+/// ```text
+/// postgresql://[user[:password]@][host][:port][/dbname][?param1=value1&...]
+/// ```
+///
+/// This type also implements [`FromStr`][std::str::FromStr] so you can parse it from a string
+/// containing a connection URL and then further adjust options if necessary (see example below).
+///
+/// ## Parameters
+///
+/// |Parameter|Default|Description|
+/// |---------|-------|-----------|
+/// | `sslmode` | `prefer` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`PgSslMode`]. |
+/// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. |
+/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. |
+/// | `host` | `None` | Path to the directory containing a PostgreSQL unix domain socket, which will be used instead of TCP if set. |
+/// | `hostaddr` | `None` | Same as `host`, but only accepts IP addresses. |
+/// | `application-name` | `None` | The name will be displayed in the pg_stat_activity view and included in CSV log entries. |
+/// | `user` | result of `whoami` | PostgreSQL user name to connect as. |
+/// | `password` | `None` | Password to be used if the server demands password authentication. |
+/// | `port` | `5432` | Port number to connect to at the server host, or socket file name extension for Unix-domain connections. |
+/// | `dbname` | `None` | The database name. |
+/// | `options` | `None` | The runtime parameters to send to the server at connection start. |
+///
+/// The URL scheme designator can be either `postgresql://` or `postgres://`.
+/// Each of the URL parts is optional.
+///
+/// ```text
+/// postgresql://
+/// postgresql://localhost
+/// postgresql://localhost:5433
+/// postgresql://localhost/mydb
+/// postgresql://user@localhost
+/// postgresql://user:secret@localhost
+/// postgresql://localhost?dbname=mydb&user=postgres&password=postgres
+/// ```
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use sqlx::{Connection, ConnectOptions};
+/// use sqlx::postgres::{PgConnectOptions, PgConnection, PgPool, PgSslMode};
+///
+/// # async fn example() -> sqlx::Result<()> {
+/// // URL connection string
+/// let conn = PgConnection::connect("postgres://localhost/mydb").await?;
+///
+/// // Manually-constructed options
+/// let conn = PgConnectOptions::new()
+///     .host("secret-host")
+///     .port(2525)
+///     .username("secret-user")
+///     .password("secret-password")
+///     .ssl_mode(PgSslMode::Require)
+///     .connect()
+///     .await?;
+///
+/// // Modifying options parsed from a string
+/// let mut opts: PgConnectOptions = "postgres://localhost/mydb".parse()?;
+///
+/// // Change the log verbosity level for queries.
+/// // Information about SQL queries is logged at `DEBUG` level by default.
+/// opts = opts.log_statements(log::LevelFilter::Trace);
+///
+/// let pool = PgPool::connect_with(opts).await?;
+/// # Ok(())
+/// # }
+/// ```
+#[derive(Debug, Clone)]
+pub struct PgConnectOptions {
+    pub(crate) host: String,
+    pub(crate) port: u16,
+    pub(crate) socket: Option<PathBuf>,
+    pub(crate) username: String,
+    pub(crate) password: Option<String>,
+    pub(crate) database: Option<String>,
+    pub(crate) ssl_mode: PgSslMode,
+    pub(crate) ssl_root_cert: Option<CertificateInput>,
+    pub(crate) ssl_client_cert: Option<CertificateInput>,
+    pub(crate) ssl_client_key: Option<CertificateInput>,
+    pub(crate) statement_cache_capacity: usize,
+    pub(crate) application_name: Option<String>,
+    pub(crate) log_settings: LogSettings,
+    pub(crate) extra_float_digits: Option<Cow<'static, str>>,
+    pub(crate) options: Option<String>,
+}
+
+impl Default for PgConnectOptions {
+    fn default() -> Self {
+        Self::new_without_pgpass().apply_pgpass()
+    }
+}
+
+impl PgConnectOptions {
+    /// Creates a new, default set of options ready for configuration.
+    ///
+    /// By default, this reads the following environment variables and sets their
+    /// equivalent options.
+    ///
+    /// * `PGHOST`
+    /// * `PGPORT`
+    /// * `PGUSER`
+    /// * `PGPASSWORD`
+    /// * `PGDATABASE`
+    /// * `PGSSLROOTCERT`
+    /// * `PGSSLCERT`
+    /// * `PGSSLKEY`
+    /// * `PGSSLMODE`
+    /// * `PGAPPNAME`
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new();
+    /// ```
+    pub fn new() -> Self {
+        Self::new_without_pgpass().apply_pgpass()
+    }
+
+    pub fn new_without_pgpass() -> Self {
+        let port = var("PGPORT")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(5432);
+
+        let host = var("PGHOST").ok().unwrap_or_else(|| default_host(port));
+
+        let username = var("PGUSER").ok().unwrap_or_else(whoami::username);
+
+        let database = var("PGDATABASE").ok();
+
+        PgConnectOptions {
+            port,
+            host,
+            socket: None,
+            username,
+            password: var("PGPASSWORD").ok(),
+            database,
+            ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from),
+            ssl_client_cert: var("PGSSLCERT").ok().map(CertificateInput::from),
+            ssl_client_key: var("PGSSLKEY").ok().map(CertificateInput::from),
+            ssl_mode: var("PGSSLMODE")
+                .ok()
+                .and_then(|v| v.parse().ok())
+                .unwrap_or_default(),
+            statement_cache_capacity: 100,
+            application_name: var("PGAPPNAME").ok(),
+            extra_float_digits: Some("2".into()),
+            log_settings: Default::default(),
+            options: var("PGOPTIONS").ok(),
+        }
+    }
+
+    pub(crate) fn apply_pgpass(mut self) -> Self {
+        if self.password.is_none() {
+            self.password = pgpass::load_password(
+                &self.host,
+                self.port,
+                &self.username,
+                self.database.as_deref(),
+            );
+        }
+
+        self
+    }
+
+    /// Sets the name of the host to connect to.
+    ///
+    /// If a host name begins with a slash, it specifies
+    /// Unix-domain communication rather than TCP/IP communication; the value is the name of
+    /// the directory in which the socket file is stored.
+    ///
+    /// The default behavior when host is not specified, or is empty,
+    /// is to connect to a Unix-domain socket.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .host("localhost");
+    /// ```
+    pub fn host(mut self, host: &str) -> Self {
+        host.clone_into(&mut self.host);
+        self
+    }
+
+    /// Sets the port to connect to at the server host.
+    ///
+    /// The default port for PostgreSQL is `5432`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .port(5432);
+    /// ```
+    pub fn port(mut self, port: u16) -> Self {
+        self.port = port;
+        self
+    }
+
+    /// Sets a custom path to a directory containing a unix domain socket,
+    /// switching the connection method from TCP to the corresponding socket.
+    ///
+    /// By default set to `None`.
+    pub fn socket(mut self, path: impl AsRef<Path>) -> Self {
+        self.socket = Some(path.as_ref().to_path_buf());
+        self
+    }
+
+    /// Sets the username to connect as.
+    ///
+    /// Defaults to be the same as the operating system name of
+    /// the user running the application.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .username("postgres");
+    /// ```
+    pub fn username(mut self, username: &str) -> Self {
+        username.clone_into(&mut self.username);
+        self
+    }
+
+    /// Sets the password to use if the server demands password authentication.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .username("root")
+    ///     .password("safe-and-secure");
+    /// ```
+    pub fn password(mut self, password: &str) -> Self {
+        self.password = Some(password.to_owned());
+        self
+    }
+
+    /// Sets the database name. Defaults to be the same as the user name.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .database("postgres");
+    /// ```
+    pub fn database(mut self, database: &str) -> Self {
+        self.database = Some(database.to_owned());
+        self
+    }
+
+    /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated
+    /// with the server.
+    ///
+    /// By default, the SSL mode is [`Prefer`](PgSslMode::Prefer), and the client will
+    /// first attempt an SSL connection but fall back to a non-SSL connection on failure.
+    ///
+    /// Ignored for Unix domain socket communication.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    /// let options = PgConnectOptions::new()
+    ///     .ssl_mode(PgSslMode::Require);
+    /// ```
+    pub fn ssl_mode(mut self, mode: PgSslMode) -> Self {
+        self.ssl_mode = mode;
+        self
+    }
+
+    /// Sets the name of a file containing SSL certificate authority (CA) certificate(s).
+    /// If the file exists, the server's certificate will be verified to be signed by
+    /// one of these authorities.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_root_cert("./ca-certificate.crt");
+    /// ```
+    pub fn ssl_root_cert(mut self, cert: impl AsRef<Path>) -> Self {
+        self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
+        self
+    }
+
+    /// Sets the name of a file containing SSL client certificate.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_client_cert("./client.crt");
+    /// ```
+    pub fn ssl_client_cert(mut self, cert: impl AsRef<Path>) -> Self {
+        self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
+        self
+    }
+
+    /// Sets the SSL client certificate as a PEM-encoded byte slice.
+    ///
+    /// This should be an ASCII-encoded blob that starts with `-----BEGIN CERTIFICATE-----`.
+    ///
+    /// # Example
+    /// Note: embedding SSL certificates and keys in the binary is not advised.
+    /// This is for illustration purposes only.
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    ///
+    /// const CERT: &[u8] = b"\
+    /// -----BEGIN CERTIFICATE-----
+    ///
+    /// -----END CERTIFICATE-----";
+    ///
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_client_cert_from_pem(CERT);
+    /// ```
+    pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self {
+        self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec()));
+        self
+    }
+
+    /// Sets the name of a file containing SSL client key.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_client_key("./client.key");
+    /// ```
+    pub fn ssl_client_key(mut self, key: impl AsRef<Path>) -> Self {
+        self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf()));
+        self
+    }
+
+    /// Sets the SSL client key as a PEM-encoded byte slice.
+    ///
+    /// This should be an ASCII-encoded blob that starts with `-----BEGIN PRIVATE KEY-----`.
+    ///
+    /// # Example
+    /// Note: embedding SSL certificates and keys in the binary is not advised.
+    /// This is for illustration purposes only.
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    ///
+    /// const KEY: &[u8] = b"\
+    /// -----BEGIN PRIVATE KEY-----
+    ///
+    /// -----END PRIVATE KEY-----";
+    ///
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_client_key_from_pem(KEY);
+    /// ```
+    pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self {
+        self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec()));
+        self
+    }
+
+    /// Sets PEM encoded trusted SSL Certificate Authorities (CA).
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgSslMode, PgConnectOptions};
+    /// let options = PgConnectOptions::new()
+    ///     // Providing a CA certificate with less than VerifyCa is pointless
+    ///     .ssl_mode(PgSslMode::VerifyCa)
+    ///     .ssl_root_cert_from_pem(vec![]);
+    /// ```
+    pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
+        self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate));
+        self
+    }
+
+    /// Sets the capacity of the connection's statement cache in a number of stored
+    /// distinct statements. Caching is handled using LRU, meaning when the
+    /// number of queries hits the defined limit, the oldest statement will get
+    /// dropped.
+    ///
+    /// The default cache capacity is 100 statements.
+    pub fn statement_cache_capacity(mut self, capacity: usize) -> Self {
+        self.statement_cache_capacity = capacity;
+        self
+    }
+
+    /// Sets the application name. Defaults to `None`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .application_name("my-app");
+    /// ```
+    pub fn application_name(mut self, application_name: &str) -> Self {
+        self.application_name = Some(application_name.to_owned());
+        self
+    }
+
+    /// Sets or removes the `extra_float_digits` connection option.
+    ///
+    /// This changes the default precision of floating-point values returned in text mode (when
+    /// not using prepared statements such as calling methods of [`Executor`] directly).
+    ///
+    /// Historically, Postgres would by default round floating-point values to 6 and 15 digits
+    /// for `float4`/`REAL` (`f32`) and `float8`/`DOUBLE` (`f64`), respectively, which would mean
+    /// that the returned value may not be exactly the same as its representation in Postgres.
+    ///
+    /// The nominal range for this value is `-15` to `3`, where negative values for this option
+    /// cause floating-points to be rounded to that many fewer digits than normal (`-1` causes
+    /// `float4` to be rounded to 5 digits instead of six, or 14 instead of 15 for `float8`),
+    /// positive values cause Postgres to emit that many extra digits of precision over default
+    /// (or simply use maximum precision in Postgres 12 and later),
+    /// and 0 means keep the default behavior (or the "old" behavior described above
+    /// as of Postgres 12).
+    ///
+    /// SQLx sets this value to 3 by default, which tells Postgres to return floating-point values
+    /// at their maximum precision in the hope that the parsed value will be identical to its
+    /// counterpart in Postgres. This is also the default in Postgres 12 and later anyway.
+    ///
+    /// However, older versions of Postgres and alternative implementations that talk the Postgres
+    /// protocol may not support this option, or the full range of values.
+    ///
+    /// If you get an error like "unknown option `extra_float_digits`" when connecting, try
+    /// setting this to `None` or consult the manual of your database for the allowed range
+    /// of values.
+    ///
+    /// For more information, see:
+    /// * [Postgres manual, 20.11.2: Client Connection Defaults; Locale and Formatting][20.11.2]
+    /// * [Postgres manual, 8.1.3: Numeric Types; Floating-point Types][8.1.3]
+    ///
+    /// [`Executor`]: crate::executor::Executor
+    /// [20.11.2]: https://www.postgresql.org/docs/current/runtime-config-client.html#RUNTIME-CONFIG-CLIENT-FORMAT
+    /// [8.1.3]: https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT
+    ///
+    /// ### Examples
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    ///
+    /// let mut options = PgConnectOptions::new()
+    ///     // for Redshift and Postgres 10
+    ///     .extra_float_digits(2);
+    ///
+    /// let mut options = PgConnectOptions::new()
+    ///     // don't send the option at all (Postgres 9 and older)
+    ///     .extra_float_digits(None);
+    /// ```
+    pub fn extra_float_digits(mut self, extra_float_digits: impl Into<Option<i8>>) -> Self {
+        self.extra_float_digits = extra_float_digits.into().map(|it| it.to_string().into());
+        self
+    }
+
+    /// Set additional startup options for the connection as a list of key-value pairs.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .options([("geqo", "off"), ("statement_timeout", "5min")]);
+    /// ```
+    pub fn options<K, V, I>(mut self, options: I) -> Self
+    where
+        K: Display,
+        V: Display,
+        I: IntoIterator<Item = (K, V)>,
+    {
+        // Do this in here so `options_str` is only set if we have an option to insert
+        let options_str = self.options.get_or_insert_with(String::new);
+        for (k, v) in options {
+            if !options_str.is_empty() {
+                options_str.push(' ');
+            }
+
+            write!(options_str, "-c {k}={v}").expect("failed to write an option to the string");
+        }
+        self
+    }
+
+    /// We try using a socket if hostname starts with `/` or if socket parameter
+    /// is specified.
+    pub(crate) fn fetch_socket(&self) -> Option<String> {
+        match self.socket {
+            Some(ref socket) => {
+                let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port);
+                Some(full_path)
+            }
+            None if self.host.starts_with('/') => {
+                let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port);
+                Some(full_path)
+            }
+            _ => None,
+        }
+    }
+}
+
+impl PgConnectOptions {
+    /// Get the current host.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .host("127.0.0.1");
+    /// assert_eq!(options.get_host(), "127.0.0.1");
+    /// ```
+    pub fn get_host(&self) -> &str {
+        &self.host
+    }
+
+    /// Get the server's port.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .port(6543);
+    /// assert_eq!(options.get_port(), 6543);
+    /// ```
+    pub fn get_port(&self) -> u16 {
+        self.port
+    }
+
+    /// Get the socket path.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .socket("/tmp");
+    /// assert!(options.get_socket().is_some());
+    /// ```
+    pub fn get_socket(&self) -> Option<&PathBuf> {
+        self.socket.as_ref()
+    }
+
+    /// Get the user name.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .username("foo");
+    /// assert_eq!(options.get_username(), "foo");
+    /// ```
+    pub fn get_username(&self) -> &str {
+        &self.username
+    }
+
+    /// Get the current database name.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .database("postgres");
+    /// assert!(options.get_database().is_some());
+    /// ```
+    pub fn get_database(&self) -> Option<&str> {
+        self.database.as_deref()
+    }
+
+    /// Get the SSL mode.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::{PgConnectOptions, PgSslMode};
+    /// let options = PgConnectOptions::new();
+    /// assert!(matches!(options.get_ssl_mode(), PgSslMode::Prefer));
+    /// ```
+    pub fn get_ssl_mode(&self) -> PgSslMode {
+        self.ssl_mode
+    }
+
+    /// Get the application name.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .application_name("service");
+    /// assert!(options.get_application_name().is_some());
+    /// ```
+    pub fn get_application_name(&self) -> Option<&str> {
+        self.application_name.as_deref()
+    }
+
+    /// Get the options.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let options = PgConnectOptions::new()
+    ///     .options([("foo", "bar")]);
+    /// assert!(options.get_options().is_some());
+    /// ```
+    pub fn get_options(&self) -> Option<&str> {
+        self.options.as_deref()
+    }
+}
+
+fn default_host(port: u16) -> String {
+    // try to check for the existence of a unix socket and use that
+    let socket = format!(".s.PGSQL.{port}");
+    let candidates = [
+        "/var/run/postgresql", // Debian
+        "/private/tmp",        // OSX (homebrew)
+        "/tmp",                // Default
+    ];
+
+    for candidate in &candidates {
+        if Path::new(candidate).join(&socket).exists() {
+            return candidate.to_string();
+        }
+    }
+
+    // fallback to localhost if no socket was found
+    "localhost".to_owned()
+}
+
+#[test]
+fn test_options_formatting() {
+    let options = PgConnectOptions::new().options([("geqo", "off")]);
+    assert_eq!(options.options, Some("-c geqo=off".to_string()));
+    let options = options.options([("search_path", "sqlx")]);
+    assert_eq!(
+        options.options,
+        Some("-c geqo=off -c search_path=sqlx".to_string())
+    );
+    let options = PgConnectOptions::new().options([("geqo", "off"), ("statement_timeout", "5min")]);
+    assert_eq!(
+        options.options,
+        Some("-c geqo=off -c statement_timeout=5min".to_string())
+    );
+    let options = PgConnectOptions::new();
+    assert_eq!(options.options, None);
+}
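Editor's note: before the parser that follows, a hedged sketch of the URL equivalences it is expected to accept, grounded in the tests further down (the `sqlx::postgres` re-export paths are assumed):

    use std::str::FromStr;
    use sqlx::postgres::PgConnectOptions;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // libpq-style query parameters...
        let a = PgConnectOptions::from_str("postgres:///?host=db.example.com&port=6432")?;
        // ...behave like the equivalent authority component:
        let b = PgConnectOptions::from_str("postgres://db.example.com:6432")?;
        assert_eq!(a.get_host(), b.get_host());
        assert_eq!(a.get_port(), b.get_port());
        Ok(())
    }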
diff --git a/patches/sqlx-postgres/src/options/parse.rs b/patches/sqlx-postgres/src/options/parse.rs
new file mode 100644
index 000000000..104001007
--- /dev/null
+++ b/patches/sqlx-postgres/src/options/parse.rs
@@ -0,0 +1,338 @@
+use crate::error::Error;
+use crate::{PgConnectOptions, PgSslMode};
+use sqlx_core::percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC};
+use sqlx_core::Url;
+use std::net::IpAddr;
+use std::str::FromStr;
+
+impl PgConnectOptions {
+    pub(crate) fn parse_from_url(url: &Url) -> Result<Self, Error> {
+        let mut options = Self::new_without_pgpass();
+
+        if let Some(host) = url.host_str() {
+            let host_decoded = percent_decode_str(host);
+            options = match host_decoded.clone().next() {
+                Some(b'/') => options.socket(&*host_decoded.decode_utf8().map_err(Error::config)?),
+                _ => options.host(host),
+            }
+        }
+
+        if let Some(port) = url.port() {
+            options = options.port(port);
+        }
+
+        let username = url.username();
+        if !username.is_empty() {
+            options = options.username(
+                &percent_decode_str(username)
+                    .decode_utf8()
+                    .map_err(Error::config)?,
+            );
+        }
+
+        if let Some(password) = url.password() {
+            options = options.password(
+                &percent_decode_str(password)
+                    .decode_utf8()
+                    .map_err(Error::config)?,
+            );
+        }
+
+        let path = url.path().trim_start_matches('/');
+        if !path.is_empty() {
+            options = options.database(path);
+        }
+
+        for (key, value) in url.query_pairs().into_iter() {
+            match &*key {
+                "sslmode" | "ssl-mode" => {
+                    options = options.ssl_mode(value.parse().map_err(Error::config)?);
+                }
+
+                "sslrootcert" | "ssl-root-cert" | "ssl-ca" => {
+                    options = options.ssl_root_cert(&*value);
+                }
+
+                "sslcert" | "ssl-cert" => options = options.ssl_client_cert(&*value),
+
+                "sslkey" | "ssl-key" => options = options.ssl_client_key(&*value),
+
+                "statement-cache-capacity" => {
+                    options =
+                        options.statement_cache_capacity(value.parse().map_err(Error::config)?);
+                }
+
+                "host" => {
+                    if value.starts_with('/') {
+                        options = options.socket(&*value);
+                    } else {
+                        options = options.host(&value);
+                    }
+                }
+
+                "hostaddr" => {
+                    value.parse::<IpAddr>().map_err(Error::config)?;
+                    options = options.host(&value)
+                }
+
+                "port" => options = options.port(value.parse().map_err(Error::config)?),
+
+                "dbname" => options = options.database(&value),
+
+                "user" => options = options.username(&value),
+
+                "password" => options = options.password(&value),
+
+                "application_name" => options = options.application_name(&value),
+
+                "options" => {
+                    if let Some(options) = options.options.as_mut() {
+                        options.push(' ');
+                        options.push_str(&value);
+                    } else {
+                        options.options = Some(value.to_string());
+                    }
+                }
+
+                k if k.starts_with("options[") => {
+                    if let Some(key) = k.strip_prefix("options[").unwrap().strip_suffix(']') {
+                        options = options.options([(key, &*value)]);
+                    }
+                }
+
+                _ => tracing::warn!(%key, %value, "ignoring unrecognized connect parameter"),
+            }
+        }
+
+        let options = options.apply_pgpass();
+
+        Ok(options)
+    }
+
+    pub(crate) fn build_url(&self) -> Url {
+        let host = match &self.socket {
+            Some(socket) => {
+                utf8_percent_encode(&socket.to_string_lossy(), NON_ALPHANUMERIC).to_string()
+            }
+            None => self.host.to_owned(),
+        };
+
+        let mut url = Url::parse(&format!(
+            "postgres://{}@{}:{}",
+            self.username, host, self.port
+        ))
+        .expect("BUG: generated un-parseable URL");
+
+        if let Some(password) = &self.password {
+            let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string();
+            let _ = url.set_password(Some(&password));
+        }
+
+        if let Some(database) = &self.database {
+            url.set_path(database);
+        }
+
+        let ssl_mode = match self.ssl_mode {
+            PgSslMode::Allow => "allow",
+            PgSslMode::Disable => "disable",
+            PgSslMode::Prefer => "prefer",
+            PgSslMode::Require => "require",
+            PgSslMode::VerifyCa => "verify-ca",
+            PgSslMode::VerifyFull => "verify-full",
+        };
+        url.query_pairs_mut().append_pair("sslmode", ssl_mode);
+
+        if let Some(ssl_root_cert) = &self.ssl_root_cert {
+            url.query_pairs_mut()
+                .append_pair("sslrootcert", &ssl_root_cert.to_string());
+        }
+
+        if let Some(ssl_client_cert) = &self.ssl_client_cert {
+            url.query_pairs_mut()
+                .append_pair("sslcert", &ssl_client_cert.to_string());
+        }
+
+        if let Some(ssl_client_key) = &self.ssl_client_key {
+            url.query_pairs_mut()
+                .append_pair("sslkey", &ssl_client_key.to_string());
+        }
+
+        url.query_pairs_mut().append_pair(
+            "statement-cache-capacity",
+            &self.statement_cache_capacity.to_string(),
+        );
+
+        url
+    }
+}
+
+impl FromStr for PgConnectOptions {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let url: Url = s.parse().map_err(Error::config)?;
+
+        Self::parse_from_url(&url)
+    }
+}
+
+#[test]
+fn it_parses_socket_correctly_from_parameter() {
+    let url = "postgres:///?host=/var/run/postgres/";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(Some("/var/run/postgres/".into()), opts.socket);
+}
+
+#[test]
+fn it_parses_host_correctly_from_parameter() {
+    let url = "postgres:///?host=google.database.com";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!("google.database.com", &opts.host);
+}
+
+#[test]
+fn it_parses_hostaddr_correctly_from_parameter() {
+    let url = "postgres:///?hostaddr=8.8.8.8";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!("8.8.8.8", &opts.host);
+}
+
+#[test]
+fn it_parses_port_correctly_from_parameter() {
+    let url = "postgres:///?port=1234";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!(1234, opts.port);
+}
+
+#[test]
+fn it_parses_dbname_correctly_from_parameter() {
+    let url = "postgres:///?dbname=some_db";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!(Some("some_db"), opts.database.as_deref());
+}
+
+#[test]
+fn it_parses_user_correctly_from_parameter() {
+    let url = "postgres:///?user=some_user";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!("some_user", opts.username);
+}
+
+#[test]
+fn it_parses_password_correctly_from_parameter() {
+    let url = "postgres:///?password=some_pass";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(None, opts.socket);
+    assert_eq!(Some("some_pass"), opts.password.as_deref());
+}
+
+#[test]
+fn it_parses_application_name_correctly_from_parameter() {
+    let url = "postgres:///?application_name=some_name";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(Some("some_name"), opts.application_name.as_deref());
+}
+
+#[test]
+fn it_parses_username_with_at_sign_correctly() {
+    let url = "postgres://user@hostname:password@hostname:5432/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!("user@hostname", &opts.username);
+}
+
+#[test]
+fn it_parses_password_with_non_ascii_chars_correctly() {
+    let url = "postgres://username:p@ssw0rd@hostname:5432/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(Some("p@ssw0rd".into()), opts.password);
+}
+
+#[test]
+fn it_parses_socket_correctly_percent_encoded() {
+    let url = "postgres://%2Fvar%2Flib%2Fpostgres/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(Some("/var/lib/postgres/".into()), opts.socket);
+}
+
+#[test]
+fn it_parses_socket_correctly_with_username_percent_encoded() {
+    let url = "postgres://some_user@%2Fvar%2Flib%2Fpostgres/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!("some_user", opts.username);
+    assert_eq!(Some("/var/lib/postgres/".into()), opts.socket);
+    assert_eq!(Some("database"), opts.database.as_deref());
+}
+
+#[test]
+fn it_parses_libpq_options_correctly() {
+    let url = "postgres:///?options=-c%20synchronous_commit%3Doff%20--search_path%3Dpostgres";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(
+        Some("-c synchronous_commit=off --search_path=postgres".into()),
+        opts.options
+    );
+}
+
+#[test]
+fn it_parses_sqlx_options_correctly() {
+    let url = "postgres:///?options[synchronous_commit]=off&options[search_path]=postgres";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    assert_eq!(
+        Some("-c synchronous_commit=off -c search_path=postgres".into()),
+        opts.options
+    );
+}
+
+#[test]
+fn it_returns_the_parsed_url_when_socket() {
+    let url = "postgres://username@%2Fvar%2Flib%2Fpostgres/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    let mut expected_url = Url::parse(url).unwrap();
+    // PgConnectOptions defaults
+    let query_string = "sslmode=prefer&statement-cache-capacity=100";
+    let port = 5432;
+    expected_url.set_query(Some(query_string));
+    let _ = expected_url.set_port(Some(port));
+
+    assert_eq!(expected_url, opts.build_url());
+}
+
+#[test]
+fn it_returns_the_parsed_url_when_host() {
+    let url = "postgres://username:p@ssw0rd@hostname:5432/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    let mut expected_url = Url::parse(url).unwrap();
+    // PgConnectOptions defaults
+    let query_string = "sslmode=prefer&statement-cache-capacity=100";
+    expected_url.set_query(Some(query_string));
+
+    assert_eq!(expected_url, opts.build_url());
+}
+
+#[test]
+fn built_url_can_be_parsed() {
+    let url = "postgres://username:p@ssw0rd@hostname:5432/database";
+    let opts = PgConnectOptions::from_str(url).unwrap();
+
+    let parsed = PgConnectOptions::from_str(&opts.build_url().to_string());
+
+    assert!(parsed.is_ok());
+}
diff --git a/patches/sqlx-postgres/src/options/pgpass.rs b/patches/sqlx-postgres/src/options/pgpass.rs
new file mode 100644
index 000000000..49da460da
--- /dev/null
+++ b/patches/sqlx-postgres/src/options/pgpass.rs
@@ -0,0 +1,341 @@
+use std::borrow::Cow;
+use std::env::var_os;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+use std::path::PathBuf;
+
+/// try to load a password from the various pgpass file locations
+pub fn load_password(
+    host: &str,
+    port: u16,
+    username: &str,
+    database: Option<&str>,
+) -> Option<String> {
+    let custom_file = var_os("PGPASSFILE");
+    if let Some(file) = custom_file {
+        if let Some(password) =
+            load_password_from_file(PathBuf::from(file), host, port, username, database)
+        {
+            return Some(password);
+        }
+    }
+
+    #[cfg(not(target_os = "windows"))]
+    let default_file = home::home_dir().map(|path| path.join(".pgpass"));
+    #[cfg(target_os = "windows")]
+    let default_file = {
+        use etcetera::BaseStrategy;
+
+        etcetera::base_strategy::Windows::new()
+            .ok()
+            .map(|basedirs| basedirs.data_dir().join("postgres").join("pgpass.conf"))
+    };
+    load_password_from_file(default_file?, host, port, username, database)
+}
+
+/// try to extract a password from a pgpass file
+fn load_password_from_file(
+    path: PathBuf,
+    host: &str,
+    port: u16,
+    username: &str,
+    database: Option<&str>,
+) -> Option<String> {
+    let file = File::open(&path)
+        .map_err(|e| {
+            tracing::warn!(
+                path = %path.display(),
+                "Failed to open `.pgpass` file: {e:?}",
+            );
+        })
+        .ok()?;
+
+    #[cfg(target_os = "linux")]
+    {
+        use std::os::unix::fs::PermissionsExt;
+
+        // check file permissions on linux
+
+        let metadata = file.metadata().ok()?;
+        let permissions = metadata.permissions();
+        let mode = permissions.mode();
+        if mode & 0o77 != 0 {
+            tracing::warn!(
+                path = %path.display(),
+                permissions = format!("{mode:o}"),
+                "Ignoring path. Permissions are not strict enough",
+            );
+            return None;
+        }
+    }
+
+    let reader = BufReader::new(file);
+    load_password_from_reader(reader, host, port, username, database)
+}
+
+fn load_password_from_reader(
+    mut reader: impl BufRead,
+    host: &str,
+    port: u16,
+    username: &str,
+    database: Option<&str>,
+) -> Option<String> {
+    let mut line = String::new();
+
+    // https://stackoverflow.com/a/55041833
+    fn trim_newline(s: &mut String) {
+        if s.ends_with('\n') {
+            s.pop();
+            if s.ends_with('\r') {
+                s.pop();
+            }
+        }
+    }
+
+    while let Ok(n) = reader.read_line(&mut line) {
+        if n == 0 {
+            break;
+        }
+
+        if line.starts_with('#') {
+            // comment, do nothing
+        } else {
+            // try to load password from line
+            trim_newline(&mut line);
+            if let Some(password) = load_password_from_line(&line, host, port, username, database) {
+                return Some(password);
+            }
+        }
+
+        line.clear();
+    }
+
+    None
+}
+
+/// try to check all fields & extract the password
+fn load_password_from_line(
+    mut line: &str,
+    host: &str,
+    port: u16,
+    username: &str,
+    database: Option<&str>,
+) -> Option<String> {
+    let whole_line = line;
+
+    // Pgpass line ordering: hostname, port, database, username, password
+    // See: https://www.postgresql.org/docs/9.3/libpq-pgpass.html
+    match line.trim_start().chars().next() {
+        None | Some('#') => None,
+        _ => {
+            matches_next_field(whole_line, &mut line, host)?;
+            matches_next_field(whole_line, &mut line, &port.to_string())?;
+            matches_next_field(whole_line, &mut line, database.unwrap_or_default())?;
+            matches_next_field(whole_line, &mut line, username)?;
+            Some(line.to_owned())
+        }
+    }
+}
+
+/// check if the next field matches the provided value
+fn matches_next_field(whole_line: &str, line: &mut &str, value: &str) -> Option<()> {
+    let field = find_next_field(line);
+    match field {
+        Some(field) => {
+            if field == "*" || field == value {
+                Some(())
+            } else {
+                None
+            }
+        }
+        None => {
+            tracing::warn!(line = whole_line, "Malformed line in pgpass file");
+            None
+        }
+    }
+}
+
+/// extract the next value from a line in a pgpass file
+///
+/// `line` will get updated to point behind the field and delimiter
+fn find_next_field<'a>(line: &mut &'a str) -> Option<Cow<'a, str>> {
+    let mut escaping = false;
+    let mut escaped_string = None;
+    let mut last_added = 0;
+
+    let char_indices = line.char_indices();
+    for (idx, c) in char_indices {
+        if c == ':' && !escaping {
+            let (field, rest) = line.split_at(idx);
+            *line = &rest[1..];
+
+            if let Some(mut escaped_string) = escaped_string {
+                escaped_string += &field[last_added..];
+                return Some(Cow::Owned(escaped_string));
+            } else {
+                return Some(Cow::Borrowed(field));
+            }
+        } else if c == '\\' {
+            let s = escaped_string.get_or_insert_with(String::new);
+
+            if escaping {
+                s.push('\\');
+            } else {
+                *s += &line[last_added..idx];
+            }
+
+            escaping = !escaping;
+            last_added = idx + 1;
+        } else {
+            escaping = false;
+        }
+    }
+
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{find_next_field, load_password_from_line, load_password_from_reader};
+    use std::borrow::Cow;
+
+    #[test]
+    fn test_find_next_field() {
+        fn test_case<'a>(mut input: &'a str, result: Option<Cow<'a, str>>, rest: &str) {
+            assert_eq!(find_next_field(&mut input), result);
+            assert_eq!(input, rest);
+        }
+
+        // normal field
+        test_case("foo:bar:baz", Some(Cow::Borrowed("foo")), "bar:baz");
+        // \ escaped
+        test_case(
+            "foo\\\\:bar:baz",
+            Some(Cow::Owned("foo\\".to_owned())),
+            "bar:baz",
+        );
+        // : escaped
+        test_case(
+            "foo\\::bar:baz",
+            Some(Cow::Owned("foo:".to_owned())),
+            "bar:baz",
+        );
+        // unnecessary escape
+        test_case(
+            "foo\\a:bar:baz",
+            Some(Cow::Owned("fooa".to_owned())),
+            "bar:baz",
+        );
+        // other text after escape
+        test_case(
+            "foo\\\\a:bar:baz",
+            Some(Cow::Owned("foo\\a".to_owned())),
+            "bar:baz",
+        );
+        // double escape
+        test_case(
+            "foo\\\\\\\\a:bar:baz",
+            Some(Cow::Owned("foo\\\\a".to_owned())),
+            "bar:baz",
+        );
+        // utf8 support
+        test_case("🦀:bar:baz", Some(Cow::Borrowed("🦀")), "bar:baz");
+
+        // missing delimiter (eof)
+        test_case("foo", None, "foo");
+        // missing delimiter after escape
+        test_case("foo\\:", None, "foo\\:");
+        // missing delimiter after unused trailing escape
+        test_case("foo\\", None, "foo\\");
+    }
+
+    #[test]
+    fn test_load_password_from_line() {
+        // normal
+        assert_eq!(
+            load_password_from_line(
+                "localhost:5432:bar:foo:baz",
+                "localhost",
+                5432,
+                "foo",
+                Some("bar")
+            ),
+            Some("baz".to_owned())
+        );
+        // wildcard
+        assert_eq!(
+            load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")),
+            Some("baz".to_owned())
+        );
+        // accept wildcard with missing db
+        assert_eq!(
+            load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None),
+            Some("baz".to_owned())
+        );
+
+        // doesn't match
+        assert_eq!(
+            load_password_from_line(
+                "thishost:5432:bar:foo:baz",
+                "thathost",
+                5432,
+                "foo",
+                Some("bar")
+            ),
+            None
+        );
+        // malformed entry
+        assert_eq!(
+            load_password_from_line(
+                "localhost:5432:bar:foo",
+                "localhost",
+                5432,
+                "foo",
+                Some("bar")
+            ),
+            None
+        );
+    }
+
+    #[test]
+    fn test_load_password_from_reader() {
+        let file = b"\
+localhost:5432:bar:foo:baz\n\
+# mixed line endings (also a comment!)\n\
+*:5432:bar:foo:baz\r\n\
+# trailing space, comment with CRLF!  \r\n\
+thishost:5432:bar:foo:baz \n\
+# malformed line \n\
+thathost:5432:foobar:foo\n\
+# missing trailing newline\n\
+localhost:5432:*:foo:baz
+        ";
+
+        // normal
+        assert_eq!(
+            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")),
+            Some("baz".to_owned())
+        );
+        // wildcard
+        assert_eq!(
+            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")),
+            Some("baz".to_owned())
+        );
+        // accept wildcard with missing db
+        assert_eq!(
+            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None),
+            Some("baz".to_owned())
+        );
+
+        // doesn't match
+        assert_eq!(
+            load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
+            None
+        );
+        // malformed entry
+        assert_eq!(
+            load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
+            None
+        );
+    }
+}
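Editor's note: to make the matching rules concrete, a hedged sketch of how one `.pgpass` entry resolves through the helpers above (the entry is made up; fields are `hostname:port:database:username:password`, `*` matches anything, and `\` escapes `:` or `\` within a field). Written as if inside this module's tests, since the helpers are private:

    // Host, port and username agree, and `*` accepts any database:
    let line = "db.example.com:5432:*:app_user:s3cret";
    assert_eq!(
        load_password_from_line(line, "db.example.com", 5432, "app_user", Some("appdb")),
        Some("s3cret".to_owned()),
    );

    // A differing username means no password is returned for this line:
    assert_eq!(
        load_password_from_line(line, "db.example.com", 5432, "other_user", Some("appdb")),
        None,
    );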
diff --git a/patches/sqlx-postgres/src/options/ssl_mode.rs b/patches/sqlx-postgres/src/options/ssl_mode.rs
new file mode 100644
index 000000000..657728ab0
--- /dev/null
+++ b/patches/sqlx-postgres/src/options/ssl_mode.rs
@@ -0,0 +1,53 @@
+use crate::error::Error;
+use std::str::FromStr;
+
+/// Options for controlling the level of protection provided for PostgreSQL SSL connections.
+///
+/// It is used by the [`ssl_mode`](super::PgConnectOptions::ssl_mode) method.
+#[derive(Debug, Clone, Copy, Default)]
+pub enum PgSslMode {
+    /// Only try a non-SSL connection.
+    Disable,
+
+    /// First try a non-SSL connection; if that fails, try an SSL connection.
+    Allow,
+
+    /// First try an SSL connection; if that fails, try a non-SSL connection.
+    ///
+    /// This is the default if no other mode is specified.
+    #[default]
+    Prefer,
+
+    /// Only try an SSL connection. If a root CA file is present, verify the connection
+    /// in the same way as if `VerifyCa` was specified.
+    Require,
+
+    /// Only try an SSL connection, and verify that the server certificate is issued by a
+    /// trusted certificate authority (CA).
+    VerifyCa,
+
+    /// Only try an SSL connection; verify that the server certificate is issued by a trusted
+    /// CA and that the requested server host name matches that in the certificate.
+    VerifyFull,
+}
+
+impl FromStr for PgSslMode {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Ok(match &*s.to_ascii_lowercase() {
+            "disable" => PgSslMode::Disable,
+            "allow" => PgSslMode::Allow,
+            "prefer" => PgSslMode::Prefer,
+            "require" => PgSslMode::Require,
+            "verify-ca" => PgSslMode::VerifyCa,
+            "verify-full" => PgSslMode::VerifyFull,
+
+            _ => {
+                return Err(Error::Configuration(
+                    format!("unknown value {s:?} for `ssl_mode`").into(),
+                ));
+            }
+        })
+    }
+}
diff --git a/patches/sqlx-postgres/src/query_result.rs b/patches/sqlx-postgres/src/query_result.rs
new file mode 100644
index 000000000..870c1aff6
--- /dev/null
+++ b/patches/sqlx-postgres/src/query_result.rs
@@ -0,0 +1,30 @@
+use std::iter::{Extend, IntoIterator};
+
+#[derive(Debug, Default)]
+pub struct PgQueryResult {
+    pub(super) rows_affected: u64,
+}
+
+impl PgQueryResult {
+    pub fn rows_affected(&self) -> u64 {
+        self.rows_affected
+    }
+}
+
+impl Extend<PgQueryResult> for PgQueryResult {
+    fn extend<T: IntoIterator<Item = PgQueryResult>>(&mut self, iter: T) {
+        for elem in iter {
+            self.rows_affected += elem.rows_affected;
+        }
+    }
+}
+
+#[cfg(feature = "any")]
+impl From<PgQueryResult> for crate::any::AnyQueryResult {
+    fn from(done: PgQueryResult) -> Self {
+        crate::any::AnyQueryResult {
+            rows_affected: done.rows_affected,
+            last_insert_id: None,
+        }
+    }
+}
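Editor's note: the `Extend` impl above lets row counts from several statements be folded into one result, e.g. for a multi-statement batch. A hedged sketch (the field is `pub(super)`, so this only compiles inside the parent module):

    let mut total = PgQueryResult::default();
    total.extend([
        PgQueryResult { rows_affected: 2 },
        PgQueryResult { rows_affected: 3 },
    ]);
    assert_eq!(total.rows_affected(), 5);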
diff --git a/patches/sqlx-postgres/src/row.rs b/patches/sqlx-postgres/src/row.rs
new file mode 100644
index 000000000..f9e43bb9c
--- /dev/null
+++ b/patches/sqlx-postgres/src/row.rs
@@ -0,0 +1,75 @@
+use crate::column::ColumnIndex;
+use crate::error::Error;
+use crate::message::DataRow;
+use crate::statement::PgStatementMetadata;
+use crate::value::PgValueFormat;
+use crate::{PgColumn, PgValueRef, Postgres};
+pub(crate) use sqlx_core::row::Row;
+use sqlx_core::type_checking::TypeChecking;
+use sqlx_core::value::ValueRef;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+/// Implementation of [`Row`] for PostgreSQL.
+pub struct PgRow {
+    pub(crate) data: DataRow,
+    pub(crate) format: PgValueFormat,
+    pub(crate) metadata: Arc<PgStatementMetadata>,
+}
+
+impl Row for PgRow {
+    type Database = Postgres;
+
+    fn columns(&self) -> &[PgColumn] {
+        &self.metadata.columns
+    }
+
+    fn try_get_raw<I>(&self, index: I) -> Result<PgValueRef<'_>, Error>
+    where
+        I: ColumnIndex<Self>,
+    {
+        let index = index.index(self)?;
+        let column = &self.metadata.columns[index];
+        let value = self.data.get(index);
+
+        Ok(PgValueRef {
+            format: self.format,
+            row: Some(&self.data.storage),
+            type_info: column.type_info.clone(),
+            value,
+        })
+    }
+}
+
+impl ColumnIndex<PgRow> for &'_ str {
+    fn index(&self, row: &PgRow) -> Result<usize, Error> {
+        row.metadata
+            .column_names
+            .get(*self)
+            .ok_or_else(|| Error::ColumnNotFound((*self).into()))
+            .copied()
+    }
+}
+
+impl Debug for PgRow {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "PgRow ")?;
+
+        let mut debug_map = f.debug_map();
+        for (index, column) in self.columns().iter().enumerate() {
+            match self.try_get_raw(index) {
+                Ok(value) => {
+                    debug_map.entry(
+                        &column.name,
+                        &Postgres::fmt_value_debug(&<PgValueRef as ValueRef>::to_owned(&value)),
+                    );
+                }
+                Err(error) => {
+                    debug_map.entry(&column.name, &format!("decode error: {error:?}"));
+                }
+            }
+        }
+
+        debug_map.finish()
+    }
+}
diff --git a/patches/sqlx-postgres/src/statement.rs b/patches/sqlx-postgres/src/statement.rs
new file mode 100644
index 000000000..abd553af3
--- /dev/null
+++ b/patches/sqlx-postgres/src/statement.rs
@@ -0,0 +1,86 @@
+use super::{PgColumn, PgTypeInfo};
+use crate::column::ColumnIndex;
+use crate::error::Error;
+use crate::ext::ustr::UStr;
+use crate::{PgArguments, Postgres};
+use std::borrow::Cow;
+use std::sync::Arc;
+
+pub(crate) use sqlx_core::statement::Statement;
+use sqlx_core::{Either, HashMap};
+
+#[derive(Debug, Clone)]
+pub struct PgStatement<'q> {
+    pub(crate) sql: Cow<'q, str>,
+    pub(crate) metadata: Arc<PgStatementMetadata>,
+}
+
+#[derive(Debug, Default)]
+pub(crate) struct PgStatementMetadata {
+    pub(crate) columns: Vec<PgColumn>,
+    // This `Arc` is not redundant; it's used to avoid deep-copying this map for the `Any` backend.
+    // See `sqlx-postgres/src/any.rs`
+    pub(crate) column_names: Arc<HashMap<UStr, usize>>,
+    pub(crate) parameters: Vec<PgTypeInfo>,
+}
+
+impl<'q> Statement<'q> for PgStatement<'q> {
+    type Database = Postgres;
+
+    fn to_owned(&self) -> PgStatement<'static> {
+        PgStatement::<'static> {
+            sql: Cow::Owned(self.sql.clone().into_owned()),
+            metadata: self.metadata.clone(),
+        }
+    }
+
+    fn sql(&self) -> &str {
+        &self.sql
+    }
+
+    fn parameters(&self) -> Option<Either<&[PgTypeInfo], usize>> {
+        Some(Either::Left(&self.metadata.parameters))
+    }
+
+    fn columns(&self) -> &[PgColumn] {
+        &self.metadata.columns
+    }
+
+    impl_statement_query!(PgArguments);
+}
+
+impl ColumnIndex<PgStatement<'_>> for &'_ str {
+    fn index(&self, statement: &PgStatement<'_>) -> Result<usize, Error> {
+        statement
+            .metadata
+            .column_names
+            .get(*self)
+            .ok_or_else(|| Error::ColumnNotFound((*self).into()))
+            .copied()
+    }
+}
+
+// #[cfg(feature = "any")]
+// impl<'q> From<PgStatement<'q>> for crate::any::AnyStatement<'q> {
+//     #[inline]
+//     fn from(statement: PgStatement<'q>) -> Self {
+//         crate::any::AnyStatement::<'q> {
+//             columns: statement
+//                 .metadata
+//                 .columns
+//                 .iter()
+//                 .map(|col| col.clone().into())
+//                 .collect(),
+//             column_names: statement.metadata.column_names.clone(),
+//             parameters: Some(Either::Left(
+//                 statement
+//                     .metadata
+//                     .parameters
+//                     .iter()
+//                     .map(|ty| ty.clone().into())
+//                     .collect(),
+//             )),
+//             sql: statement.sql,
+//         }
+//     }
+// }
diff --git a/patches/sqlx-postgres/src/testing/mod.rs b/patches/sqlx-postgres/src/testing/mod.rs
new file mode 100644
index 000000000..fb36ab413
--- /dev/null
+++ b/patches/sqlx-postgres/src/testing/mod.rs
@@ -0,0 +1,225 @@
+use std::fmt::Write;
+use std::ops::Deref;
+use std::str::FromStr;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::time::{Duration, SystemTime};
+
+use futures_core::future::BoxFuture;
+
+use once_cell::sync::OnceCell;
+
+use crate::connection::Connection;
+
+use crate::error::Error;
+use crate::executor::Executor;
+use crate::pool::{Pool, PoolOptions};
+use crate::query::query;
+use crate::query_scalar::query_scalar;
+use crate::{PgConnectOptions, PgConnection, Postgres};
+
+pub(crate) use sqlx_core::testing::*;
+
+// Using a blocking `OnceCell` here because the critical sections are short.
+static MASTER_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();
+// Automatically delete any databases created before the start of the test binary.
+static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
+
+impl TestSupport for Postgres {
+    fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
+        Box::pin(async move { test_context(args).await })
+    }
+
+    fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            let mut conn = MASTER_POOL
+                .get()
+                .expect("cleanup_test() invoked outside `#[sqlx::test]`")
+                .acquire()
+                .await?;
+
+            conn.execute(&format!("drop database if exists {db_name:?};")[..])
+                .await?;
+
+            query("delete from _sqlx_test.databases where db_name = $1")
+                .bind(db_name)
+                .execute(&mut *conn)
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
+        Box::pin(async move {
+            let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
+
+            let mut conn = PgConnection::connect(&url).await?;
+
+            let now = SystemTime::now()
+                .duration_since(SystemTime::UNIX_EPOCH)
+                .unwrap();
+
+            let num_deleted = do_cleanup(&mut conn, now).await?;
+            let _ = conn.close().await;
+            Ok(Some(num_deleted))
+        })
+    }
+
+    fn snapshot(
+        _conn: &mut Self::Connection,
+    ) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
+        // TODO: I want to get the testing feature out the door so this will have to wait,
+        // but I'm keeping the code around for now because I plan to come back to it.
+        todo!()
+    }
+}
+
+async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
+    let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
+
+    let master_opts = PgConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
+
+    let pool = PoolOptions::new()
+        // Postgres' normal connection limit is 100 plus 3 superuser connections
+        // We don't want to use the whole cap and there may be fuzziness here due to
+        // concurrently running tests anyway.
+        .max_connections(20)
+        // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
+        .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
+        .connect_lazy_with(master_opts);
+
+    let master_pool = match MASTER_POOL.try_insert(pool) {
+        Ok(inserted) => inserted,
+        Err((existing, pool)) => {
+            // Sanity checks.
+            assert_eq!(
+                existing.connect_options().host,
+                pool.connect_options().host,
+                "DATABASE_URL changed at runtime, host differs"
+            );
+
+            assert_eq!(
+                existing.connect_options().database,
+                pool.connect_options().database,
+                "DATABASE_URL changed at runtime, database differs"
+            );
+
+            existing
+        }
+    };
+
+    let mut conn = master_pool.acquire().await?;
+
+    // language=PostgreSQL
+    conn.execute(
+        // Explicit lock avoids this latent bug: https://stackoverflow.com/a/29908840
+        // I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically,
+        // but a clearly related bug with `CREATE TABLE` has been known since 2007:
+        // https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org
+        r#"
+        lock table pg_catalog.pg_namespace in share row exclusive mode;
+
+        create schema if not exists _sqlx_test;
+
+        create table if not exists _sqlx_test.databases (
+            db_name text primary key,
+            test_path text not null,
+            created_at timestamptz not null default now()
+        );
+
+        create index if not exists databases_created_at
+            on _sqlx_test.databases(created_at);
+
+        create sequence if not exists _sqlx_test.database_ids;
+    "#,
+    )
+    .await?;
+
+    // Record the current time _before_ we acquire the `DO_CLEANUP` permit.
+    // This prevents the first test thread from accidentally deleting new test dbs
+    // created by other test threads if we're a bit slow.
+    let now = SystemTime::now()
+        .duration_since(SystemTime::UNIX_EPOCH)
+        .unwrap();
+
+    // Only run cleanup if the test binary just started.
+    if DO_CLEANUP.swap(false, Ordering::SeqCst) {
+        do_cleanup(&mut conn, now).await?;
+    }
+
+    let new_db_name: String = query_scalar(
+        r#"
+            insert into _sqlx_test.databases(db_name, test_path)
+            select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1
+            returning db_name
+        "#,
+    )
+    .bind(args.test_path)
+    .fetch_one(&mut *conn)
+    .await?;
+
+    conn.execute(&format!("create database {new_db_name:?}")[..])
+        .await?;
+
+    Ok(TestContext {
+        pool_opts: PoolOptions::new()
+            // Don't allow a single test to take all the connections.
+            // Most tests shouldn't require more than 5 connections concurrently,
+            // or else they're likely doing too much in one test.
+            .max_connections(5)
+            // Close connections ASAP if left in the idle queue.
+            .idle_timeout(Some(Duration::from_secs(1)))
+            .parent(master_pool.clone()),
+        connect_opts: master_pool
+            .connect_options()
+            .deref()
+            .clone()
+            .database(&new_db_name),
+        db_name: new_db_name,
+    })
+}
+
+async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result<usize, Error> {
+    // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads
+    let created_before = i64::try_from(created_before.as_secs()).unwrap() - 2;
+
+    let delete_db_names: Vec<String> = query_scalar(
+        "select db_name from _sqlx_test.databases \
+         where created_at < (to_timestamp($1) at time zone 'UTC')",
+    )
+    .bind(created_before)
+    .fetch_all(&mut *conn)
+    .await?;
+
+    if delete_db_names.is_empty() {
+        return Ok(0);
+    }
+
+    let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
+    let delete_db_names = delete_db_names.into_iter();
+
+    let mut command = String::new();
+
+    for db_name in delete_db_names {
+        command.clear();
+        writeln!(command, "drop database if exists {db_name:?};").ok();
+        match conn.execute(&*command).await {
+            Ok(_deleted) => {
+                deleted_db_names.push(db_name);
+            }
+            // Assume a database error just means the DB is still in use.
+            Err(Error::Database(dbe)) => {
+                eprintln!("could not clean test database {db_name:?}: {dbe}")
+            }
+            // Bubble up other errors
+            Err(e) => return Err(e),
+        }
+    }
+
+    query("delete from _sqlx_test.databases where db_name = any($1::text[])")
+        .bind(&deleted_db_names)
+        .execute(&mut *conn)
+        .await?;
+
+    Ok(deleted_db_names.len())
+}
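Editor's note: this module is the Postgres backend for sqlx's `#[sqlx::test]` attribute. A hedged consumer-side sketch, assuming `DATABASE_URL` points at the master database and the `macros` feature is enabled:

    // Runs against a fresh `_sqlx_test_<N>` database created by
    // `test_context` above; stale databases are dropped by `do_cleanup`
    // on the next run of the test binary.
    #[sqlx::test]
    async fn inserts_a_row(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let one: i32 = sqlx::query_scalar("SELECT 1").fetch_one(&pool).await?;
        assert_eq!(one, 1);
        Ok(())
    }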
diff --git a/patches/sqlx-postgres/src/transaction.rs b/patches/sqlx-postgres/src/transaction.rs
new file mode 100644
index 000000000..b9330d529
--- /dev/null
+++ b/patches/sqlx-postgres/src/transaction.rs
@@ -0,0 +1,88 @@
+use futures_core::future::BoxFuture;
+
+use crate::error::Error;
+use crate::executor::Executor;
+
+use crate::{PgConnection, Postgres};
+
+pub(crate) use sqlx_core::transaction::*;
+
+/// Implementation of [`TransactionManager`] for PostgreSQL.
+pub struct PgTransactionManager;
+
+impl TransactionManager for PgTransactionManager {
+    type Database = Postgres;
+
+    fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            let rollback = Rollback::new(conn);
+            let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth);
+            rollback.conn.queue_simple_query(&query)?;
+            rollback.conn.transaction_depth += 1;
+            rollback.conn.wait_until_ready().await?;
+            rollback.defuse();
+
+            Ok(())
+        })
+    }
+
+    fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            if conn.transaction_depth > 0 {
+                conn.execute(&*commit_ansi_transaction_sql(conn.transaction_depth))
+                    .await?;
+
+                conn.transaction_depth -= 1;
+            }
+
+            Ok(())
+        })
+    }
+
+    fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(async move {
+            if conn.transaction_depth > 0 {
+                conn.execute(&*rollback_ansi_transaction_sql(conn.transaction_depth))
+                    .await?;
+
+                conn.transaction_depth -= 1;
+            }
+
+            Ok(())
+        })
+    }
+
+    fn start_rollback(conn: &mut PgConnection) {
+        if conn.transaction_depth > 0 {
+            conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth))
+                .expect("BUG: Rollback query somehow too large for protocol");
+
+            conn.transaction_depth -= 1;
+        }
+    }
+}
+
+struct Rollback<'c> {
+    conn: &'c mut PgConnection,
+    defuse: bool,
+}
+
+impl Drop for Rollback<'_> {
+    fn drop(&mut self) {
+        if !self.defuse {
+            PgTransactionManager::start_rollback(self.conn)
+        }
+    }
+}
+
+impl<'c> Rollback<'c> {
+    fn new(conn: &'c mut PgConnection) -> Self {
+        Self {
+            conn,
+            defuse: false,
+        }
+    }
+    fn defuse(mut self) {
+        self.defuse = true;
+    }
+}
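Editor's note: `Rollback` above is a drop guard — if queueing `BEGIN`/`SAVEPOINT` fails partway through `begin`, the guard rolls the connection back on drop, and `defuse()` disarms it on the success path. A minimal standalone sketch of the same pattern:

    struct Guard<F: FnMut()> {
        on_drop: F,
        defused: bool,
    }

    impl<F: FnMut()> Guard<F> {
        fn new(on_drop: F) -> Self {
            Self { on_drop, defused: false }
        }

        // Consuming `self` sets the flag before `Drop::drop` runs,
        // turning the cleanup into a no-op.
        fn defuse(mut self) {
            self.defused = true;
        }
    }

    impl<F: FnMut()> Drop for Guard<F> {
        fn drop(&mut self) {
            if !self.defused {
                (self.on_drop)();
            }
        }
    }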
diff --git a/patches/sqlx-postgres/src/type_checking.rs b/patches/sqlx-postgres/src/type_checking.rs
new file mode 100644
index 000000000..e22d3b900
--- /dev/null
+++ b/patches/sqlx-postgres/src/type_checking.rs
@@ -0,0 +1,215 @@
+use crate::Postgres;
+
+// The paths used below will also be emitted by the macros so they have to match the final facade.
+#[allow(unused_imports, dead_code)]
+mod sqlx {
+    pub use crate as postgres;
+    pub use sqlx_core::*;
+}
+
+impl_type_checking!(
+    Postgres {
+        (),
+        bool,
+        String | &str,
+        i8,
+        i16,
+        i32,
+        i64,
+        f32,
+        f64,
+        Vec<u8> | &[u8],
+
+        sqlx::postgres::types::Oid,
+
+        sqlx::postgres::types::PgInterval,
+
+        sqlx::postgres::types::PgMoney,
+
+        sqlx::postgres::types::PgLTree,
+
+        sqlx::postgres::types::PgLQuery,
+
+        sqlx::postgres::types::PgCube,
+
+        #[cfg(feature = "uuid")]
+        sqlx::types::Uuid,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::types::chrono::NaiveTime,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::types::chrono::NaiveDate,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::types::chrono::NaiveDateTime,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc> | sqlx::types::chrono::DateTime<_>,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::postgres::types::PgTimeTz<sqlx::types::chrono::NaiveTime, sqlx::types::chrono::FixedOffset>,
+
+        #[cfg(feature = "time")]
+        sqlx::types::time::Time,
+
+        #[cfg(feature = "time")]
+        sqlx::types::time::Date,
+
+        #[cfg(feature = "time")]
+        sqlx::types::time::PrimitiveDateTime,
+
+        #[cfg(feature = "time")]
+        sqlx::types::time::OffsetDateTime,
+
+        #[cfg(feature = "time")]
+        sqlx::postgres::types::PgTimeTz<sqlx::types::time::Time, sqlx::types::time::UtcOffset>,
+
+        #[cfg(feature = "bigdecimal")]
+        sqlx::types::BigDecimal,
+
+        #[cfg(feature = "rust_decimal")]
+        sqlx::types::Decimal,
+
+        #[cfg(feature = "ipnetwork")]
+        sqlx::types::ipnetwork::IpNetwork,
+
+        #[cfg(feature = "mac_address")]
+        sqlx::types::mac_address::MacAddress,
+
+        #[cfg(feature = "json")]
+        sqlx::types::JsonValue,
+
+        #[cfg(feature = "bit-vec")]
+        sqlx::types::BitVec,
+
+        // Arrays
+
+        Vec<bool> | &[bool],
+        Vec<String> | &[String],
+        Vec<Vec<u8>> | &[Vec<u8>],
+        Vec<i8> | &[i8],
+        Vec<i16> | &[i16],
+        Vec<i32> | &[i32],
+        Vec<i64> | &[i64],
+        Vec<f32> | &[f32],
+        Vec<f64> | &[f64],
+        Vec<sqlx::postgres::types::Oid> | &[sqlx::postgres::types::Oid],
+        Vec<sqlx::postgres::types::PgMoney> | &[sqlx::postgres::types::PgMoney],
+
+        #[cfg(feature = "uuid")]
+        Vec<sqlx::types::Uuid> | &[sqlx::types::Uuid],
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        Vec<sqlx::types::chrono::NaiveTime> | &[sqlx::types::chrono::NaiveTime],
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        Vec<sqlx::types::chrono::NaiveDate> | &[sqlx::types::chrono::NaiveDate],
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        Vec<sqlx::types::chrono::NaiveDateTime> | &[sqlx::types::chrono::NaiveDateTime],
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        Vec<sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>> | &[sqlx::types::chrono::DateTime<_>],
+
+        #[cfg(feature = "time")]
+        Vec<sqlx::types::time::Time> | &[sqlx::types::time::Time],
+
+        #[cfg(feature = "time")]
+        Vec<sqlx::types::time::Date> | &[sqlx::types::time::Date],
+
+        #[cfg(feature = "time")]
+        Vec<sqlx::types::time::PrimitiveDateTime> | &[sqlx::types::time::PrimitiveDateTime],
+
+        #[cfg(feature = "time")]
+        Vec<sqlx::types::time::OffsetDateTime> | &[sqlx::types::time::OffsetDateTime],
+
+        #[cfg(feature = "bigdecimal")]
+        Vec<sqlx::types::BigDecimal> | &[sqlx::types::BigDecimal],
+
+        #[cfg(feature = "rust_decimal")]
+        Vec<sqlx::types::Decimal> | &[sqlx::types::Decimal],
+
+        #[cfg(feature = "ipnetwork")]
+        Vec<sqlx::types::ipnetwork::IpNetwork> | &[sqlx::types::ipnetwork::IpNetwork],
+
+        #[cfg(feature = "mac_address")]
+        Vec<sqlx::types::mac_address::MacAddress> | &[sqlx::types::mac_address::MacAddress],
+
+        #[cfg(feature = "json")]
+        Vec<sqlx::types::JsonValue> | &[sqlx::types::JsonValue],
+
+        // Ranges
+
+        sqlx::postgres::types::PgRange<i32>,
+        sqlx::postgres::types::PgRange<i64>,
+
+        #[cfg(feature = "bigdecimal")]
+        sqlx::postgres::types::PgRange<sqlx::types::BigDecimal>,
+
+        #[cfg(feature = "rust_decimal")]
+        sqlx::postgres::types::PgRange<sqlx::types::Decimal>,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::postgres::types::PgRange<sqlx::types::chrono::NaiveDate>,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
+        sqlx::postgres::types::PgRange<sqlx::types::chrono::NaiveDateTime>,
+
+        #[cfg(all(feature = "chrono", not(feature = "time")))]
sqlx::postgres::types::PgRange> | + sqlx::postgres::types::PgRange>, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + // Range arrays + + Vec> | &[sqlx::postgres::types::PgRange], + Vec> | &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "bigdecimal")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "rust_decimal")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(all(feature = "chrono", not(feature = "time")))] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(all(feature = "chrono", not(feature = "time")))] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(all(feature = "chrono", not(feature = "time")))] + Vec>> | + &[sqlx::postgres::types::PgRange>], + + #[cfg(all(feature = "chrono", not(feature = "time")))] + Vec>> | + &[sqlx::postgres::types::PgRange>], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], + }, + ParamChecking::Strong, + feature-types: info => info.__type_feature_gate(), +); diff --git a/patches/sqlx-postgres/src/type_info.rs b/patches/sqlx-postgres/src/type_info.rs new file mode 100644 index 000000000..3d948f73d --- /dev/null +++ b/patches/sqlx-postgres/src/type_info.rs @@ -0,0 +1,1390 @@ +#![allow(dead_code)] + +use std::borrow::Cow; +use std::fmt::{self, Display, Formatter}; +use std::ops::Deref; +use std::sync::Arc; + +use crate::ext::ustr::UStr; +use crate::types::Oid; + +pub(crate) use sqlx_core::type_info::TypeInfo; + +/// Type information for a PostgreSQL type. +/// +/// ### Note: Implementation of `==` ([`PartialEq::eq()`]) +/// Because `==` on [`TypeInfo`]s has been used throughout the SQLx API as a synonym for type compatibility, +/// e.g. in the default impl of [`Type::compatible()`][sqlx_core::types::Type::compatible], +/// some concessions have been made in the implementation. +/// +/// When comparing two `PgTypeInfo`s using the `==` operator ([`PartialEq::eq()`]), +/// if one was constructed with [`Self::with_oid()`] and the other with [`Self::with_name()`] or +/// [`Self::array_of()`], `==` will return `true`: +/// +/// ``` +/// # use sqlx::postgres::{types::Oid, PgTypeInfo}; +/// // Potentially surprising result, this assert will pass: +/// assert_eq!(PgTypeInfo::with_oid(Oid(1)), PgTypeInfo::with_name("definitely_not_real")); +/// ``` +/// +/// Since it is not possible in this case to prove the types are _not_ compatible (because +/// both `PgTypeInfo`s need to be resolved by an active connection to know for sure) +/// and type compatibility is mainly done as a sanity check anyway, +/// it was deemed acceptable to fudge equality in this very specific case. +/// +/// This also applies when querying with the text protocol (not using prepared statements, +/// e.g. [`sqlx::raw_sql()`][sqlx_core::raw_sql::raw_sql]), as the connection will be unable +/// to look up the type info like it normally does when preparing a statement: it won't know +/// what the OIDs of the output columns will be until it's in the middle of reading the result, +/// and by that time it's too late. +/// +/// To compare types for exact equality, use [`Self::type_eq()`] instead. 
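+///
+/// ### Example (editorial sketch)
+/// A doctest contrasting the two comparisons, using the same facade paths as
+/// the example above; `type_eq()` is strict where `==` is permissive:
+///
+/// ```
+/// # use sqlx::postgres::{types::Oid, PgTypeInfo};
+/// let by_oid = PgTypeInfo::with_oid(Oid(1));
+/// let by_name = PgTypeInfo::with_name("definitely_not_real");
+///
+/// // `==` cannot disprove compatibility without a connection, so it errs
+/// // toward `true` for unresolved declarations:
+/// assert_eq!(by_oid, by_name);
+///
+/// // `type_eq()` requires provable equality and returns `false` here:
+/// assert!(!by_oid.type_eq(&by_name));
+/// ```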
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
+pub struct PgTypeInfo(pub(crate) PgType);
+
+impl Deref for PgTypeInfo {
+    type Target = PgType;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u32)]
+pub enum PgType {
+    Bool,
+    Bytea,
+    Char,
+    Name,
+    Int8,
+    Int2,
+    Int4,
+    Text,
+    Oid,
+    Json,
+    JsonArray,
+    Point,
+    Lseg,
+    Path,
+    Box,
+    Polygon,
+    Line,
+    LineArray,
+    Cidr,
+    CidrArray,
+    Float4,
+    Float8,
+    Unknown,
+    Circle,
+    CircleArray,
+    Macaddr8,
+    Macaddr8Array,
+    Macaddr,
+    Inet,
+    BoolArray,
+    ByteaArray,
+    CharArray,
+    NameArray,
+    Int2Array,
+    Int4Array,
+    TextArray,
+    BpcharArray,
+    VarcharArray,
+    Int8Array,
+    PointArray,
+    LsegArray,
+    PathArray,
+    BoxArray,
+    Float4Array,
+    Float8Array,
+    PolygonArray,
+    OidArray,
+    MacaddrArray,
+    InetArray,
+    Bpchar,
+    Varchar,
+    Date,
+    Time,
+    Timestamp,
+    TimestampArray,
+    DateArray,
+    TimeArray,
+    Timestamptz,
+    TimestamptzArray,
+    Interval,
+    IntervalArray,
+    NumericArray,
+    Timetz,
+    TimetzArray,
+    Bit,
+    BitArray,
+    Varbit,
+    VarbitArray,
+    Numeric,
+    Record,
+    RecordArray,
+    Uuid,
+    UuidArray,
+    Jsonb,
+    JsonbArray,
+    Int4Range,
+    Int4RangeArray,
+    NumRange,
+    NumRangeArray,
+    TsRange,
+    TsRangeArray,
+    TstzRange,
+    TstzRangeArray,
+    DateRange,
+    DateRangeArray,
+    Int8Range,
+    Int8RangeArray,
+    Jsonpath,
+    JsonpathArray,
+    Money,
+    MoneyArray,
+
+    // https://www.postgresql.org/docs/9.3/datatype-pseudo.html
+    Void,
+
+    // A realized user-defined type. When a connection sees a DeclareXX variant it resolves
+    // into this one before passing it along to `accepts` or inside of `Value` objects.
+    Custom(Arc<PgCustomType>),
+
+    // From [`PgTypeInfo::with_name`]
+    DeclareWithName(UStr),
+
+    // NOTE: Do we want to bring back type declaration by ID? It's notoriously fragile but
+    // someone may have a use for it
+    DeclareWithOid(Oid),
+
+    DeclareArrayOf(Arc<PgArrayOf>),
+}
+
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
+pub struct PgCustomType {
+    #[cfg_attr(feature = "offline", serde(skip))]
+    pub(crate) oid: Oid,
+    pub(crate) name: UStr,
+    pub(crate) kind: PgTypeKind,
+}
+
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
+pub enum PgTypeKind {
+    Simple,
+    Pseudo,
+    Domain(PgTypeInfo),
+    Composite(Arc<[(String, PgTypeInfo)]>),
+    Array(PgTypeInfo),
+    Enum(Arc<[String]>),
+    Range(PgTypeInfo),
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
+pub struct PgArrayOf {
+    pub(crate) elem_name: UStr,
+    pub(crate) name: Box<str>,
+}
+
+impl PgTypeInfo {
+    /// Returns the corresponding `PgTypeInfo` if the OID is a built-in type and recognized by SQLx.
+    pub(crate) fn try_from_oid(oid: Oid) -> Option<Self> {
+        PgType::try_from_oid(oid).map(Self)
+    }
+
+    /// Returns the _kind_ (simple, array, enum, etc.) for this type.
+    pub fn kind(&self) -> &PgTypeKind {
+        self.0.kind()
+    }
+
+    /// Returns the OID for this type, if available.
+    ///
+    /// The OID may not be available if SQLx only knows the type by name.
+    /// It will have to be resolved by a `PgConnection` at runtime which
+    /// will yield a new and semantically distinct `TypeInfo` instance.
+    ///
+    /// This method does not perform any such lookup.
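+    ///
+    /// For example (an illustrative name, not a shipped type):
+    /// `PgTypeInfo::with_name("citext").oid()` returns `None` until the type
+    /// has been resolved by a connection.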
+    ///
+    /// ### Note
+    /// With the exception of [the default `pg_type` catalog][pg_type], type OIDs are *not* stable in PostgreSQL.
+    /// If a type is added by an extension, its OID will be assigned when the `CREATE EXTENSION` statement is executed,
+    /// and so can change depending on what extensions are installed and in what order, as well as the exact
+    /// version of PostgreSQL.
+    ///
+    /// [pg_type]: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
+    pub fn oid(&self) -> Option<Oid> {
+        self.0.try_oid()
+    }
+
+    #[doc(hidden)]
+    pub fn __type_feature_gate(&self) -> Option<&'static str> {
+        if [
+            PgTypeInfo::DATE,
+            PgTypeInfo::TIME,
+            PgTypeInfo::TIMESTAMP,
+            PgTypeInfo::TIMESTAMPTZ,
+            PgTypeInfo::DATE_ARRAY,
+            PgTypeInfo::TIME_ARRAY,
+            PgTypeInfo::TIMESTAMP_ARRAY,
+            PgTypeInfo::TIMESTAMPTZ_ARRAY,
+        ]
+        .contains(self)
+        {
+            Some("time")
+        } else if [PgTypeInfo::UUID, PgTypeInfo::UUID_ARRAY].contains(self) {
+            Some("uuid")
+        } else if [
+            PgTypeInfo::JSON,
+            PgTypeInfo::JSONB,
+            PgTypeInfo::JSON_ARRAY,
+            PgTypeInfo::JSONB_ARRAY,
+        ]
+        .contains(self)
+        {
+            Some("json")
+        } else if [
+            PgTypeInfo::CIDR,
+            PgTypeInfo::INET,
+            PgTypeInfo::CIDR_ARRAY,
+            PgTypeInfo::INET_ARRAY,
+        ]
+        .contains(self)
+        {
+            Some("ipnetwork")
+        } else if [PgTypeInfo::MACADDR].contains(self) {
+            Some("mac_address")
+        } else if [PgTypeInfo::NUMERIC, PgTypeInfo::NUMERIC_ARRAY].contains(self) {
+            Some("bigdecimal")
+        } else {
+            None
+        }
+    }
+
+    /// Create a `PgTypeInfo` from a type name.
+    ///
+    /// The OID for the type will be fetched from Postgres on use of
+    /// a value of this type. The fetched OID will be cached per-connection.
+    ///
+    /// ### Note: Type Names Prefixed with `_`
+    /// In `pg_catalog.pg_type`, Postgres prefixes a type name with `_` to denote an array of that
+    /// type, e.g. `int4[]` actually exists in `pg_type` as `_int4`.
+    ///
+    /// Previously, it was necessary in manual [`PgHasArrayType`][crate::PgHasArrayType] impls
+    /// to return [`PgTypeInfo::with_name()`] with the type name prefixed with `_` to denote
+    /// an array type, but this would not work with schema-qualified names.
+    ///
+    /// As of 0.8, [`PgTypeInfo::array_of()`] is used to declare an array type,
+    /// and the Postgres driver is now able to properly resolve arrays of custom types,
+    /// even in other schemas, which was not previously supported.
+    ///
+    /// It is highly recommended to migrate existing usages to [`PgTypeInfo::array_of()`] where
+    /// applicable.
+    ///
+    /// However, to maintain compatibility, the driver now infers any type name prefixed with `_`
+    /// to be an array of that type. This may introduce some breakages for types which use
+    /// a `_` prefix but which are not arrays.
+    ///
+    /// As a workaround, type names with `_` as a prefix but which are not arrays should be wrapped
+    /// in quotes, e.g.:
+    /// ```
+    /// use sqlx::postgres::PgTypeInfo;
+    /// use sqlx::{Type, TypeInfo};
+    ///
+    /// /// `CREATE TYPE "_foo" AS ENUM ('Bar', 'Baz');`
+    /// #[derive(sqlx::Type)]
+    /// // Will prevent SQLx from inferring `_foo` as an array type.
+    /// #[sqlx(type_name = r#""_foo""#)]
+    /// enum Foo {
+    ///     Bar,
+    ///     Baz
+    /// }
+    ///
+    /// assert_eq!(Foo::type_info().name(), r#""_foo""#);
+    /// ```
+    pub const fn with_name(name: &'static str) -> Self {
+        Self(PgType::DeclareWithName(UStr::Static(name)))
+    }
+
+    /// Create a `PgTypeInfo` of an array from the name of its element type.
+    ///
+    /// The array type OID will be fetched from Postgres on use of a value of this type.
+    /// The fetched OID will be cached per-connection.
+    pub fn array_of(elem_name: &'static str) -> Self {
+        // to satisfy `name()` and `display_name()`, we need to construct strings to return
+        Self(PgType::DeclareArrayOf(Arc::new(PgArrayOf {
+            elem_name: elem_name.into(),
+            name: format!("{elem_name}[]").into(),
+        })))
+    }
+
+    /// Create a `PgTypeInfo` from an OID.
+    ///
+    /// Note that the OID for a type is very dependent on the environment. If you only ever use
+    /// one database or if this is an unhandled built-in type, you should be fine. Otherwise,
+    /// you will be better served using [`Self::with_name()`].
+    ///
+    /// ### Note: Interaction with `==`
+    /// This constructor may give surprising results with `==`.
+    ///
+    /// See [the type-level docs][Self] for details.
+    pub const fn with_oid(oid: Oid) -> Self {
+        Self(PgType::DeclareWithOid(oid))
+    }
+
+    /// Returns `true` if `self` can be compared exactly to `other`.
+    ///
+    /// Unlike `==`, this will return `false` if, for example, one type was declared by OID and
+    /// the other by name, since their equality cannot be proven without an active connection.
+    pub fn type_eq(&self, other: &Self) -> bool {
+        self.eq_impl(other, false)
+    }
+}
+
+// DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file
+// https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
+//
+// If you have Postgres running locally you can also try
+// SELECT oid, typarray FROM pg_type where typname = '<type name>'
+
+impl PgType {
+    /// Returns the corresponding `PgType` if the OID is a built-in type and recognized by SQLx.
+    pub(crate) fn try_from_oid(oid: Oid) -> Option<Self> {
+        Some(match oid.0 {
+            16 => PgType::Bool,
+            17 => PgType::Bytea,
+            18 => PgType::Char,
+            19 => PgType::Name,
+            20 => PgType::Int8,
+            21 => PgType::Int2,
+            23 => PgType::Int4,
+            25 => PgType::Text,
+            26 => PgType::Oid,
+            114 => PgType::Json,
+            199 => PgType::JsonArray,
+            600 => PgType::Point,
+            601 => PgType::Lseg,
+            602 => PgType::Path,
+            603 => PgType::Box,
+            604 => PgType::Polygon,
+            628 => PgType::Line,
+            629 => PgType::LineArray,
+            650 => PgType::Cidr,
+            651 => PgType::CidrArray,
+            700 => PgType::Float4,
+            701 => PgType::Float8,
+            705 => PgType::Unknown,
+            718 => PgType::Circle,
+            719 => PgType::CircleArray,
+            774 => PgType::Macaddr8,
+            775 => PgType::Macaddr8Array,
+            790 => PgType::Money,
+            791 => PgType::MoneyArray,
+            829 => PgType::Macaddr,
+            869 => PgType::Inet,
+            1000 => PgType::BoolArray,
+            1001 => PgType::ByteaArray,
+            1002 => PgType::CharArray,
+            1003 => PgType::NameArray,
+            1005 => PgType::Int2Array,
+            1007 => PgType::Int4Array,
+            1009 => PgType::TextArray,
+            1014 => PgType::BpcharArray,
+            1015 => PgType::VarcharArray,
+            1016 => PgType::Int8Array,
+            1017 => PgType::PointArray,
+            1018 => PgType::LsegArray,
+            1019 => PgType::PathArray,
+            1020 => PgType::BoxArray,
+            1021 => PgType::Float4Array,
+            1022 => PgType::Float8Array,
+            1027 => PgType::PolygonArray,
+            1028 => PgType::OidArray,
+            1040 => PgType::MacaddrArray,
+            1041 => PgType::InetArray,
+            1042 => PgType::Bpchar,
+            1043 => PgType::Varchar,
+            1082 => PgType::Date,
+            1083 => PgType::Time,
+            1114 => PgType::Timestamp,
+            1115 => PgType::TimestampArray,
+            1182 => PgType::DateArray,
+            1183 => PgType::TimeArray,
+            1184 => PgType::Timestamptz,
+            1185 => PgType::TimestamptzArray,
+            1186 => PgType::Interval,
+            1187 => PgType::IntervalArray,
+            1231 => PgType::NumericArray,
+            1266 => PgType::Timetz,
+            1270 => PgType::TimetzArray,
+            1560 => PgType::Bit,
+            1561 => PgType::BitArray,
+            1562 => PgType::Varbit,
+            1563 => PgType::VarbitArray,
+            1700 => PgType::Numeric,
+            2278 => PgType::Void,
+            2249 => PgType::Record,
+            2287 => PgType::RecordArray,
+
2950 => PgType::Uuid, + 2951 => PgType::UuidArray, + 3802 => PgType::Jsonb, + 3807 => PgType::JsonbArray, + 3904 => PgType::Int4Range, + 3905 => PgType::Int4RangeArray, + 3906 => PgType::NumRange, + 3907 => PgType::NumRangeArray, + 3908 => PgType::TsRange, + 3909 => PgType::TsRangeArray, + 3910 => PgType::TstzRange, + 3911 => PgType::TstzRangeArray, + 3912 => PgType::DateRange, + 3913 => PgType::DateRangeArray, + 3926 => PgType::Int8Range, + 3927 => PgType::Int8RangeArray, + 4072 => PgType::Jsonpath, + 4073 => PgType::JsonpathArray, + + _ => { + return None; + } + }) + } + + pub(crate) fn oid(&self) -> Oid { + match self.try_oid() { + Some(oid) => oid, + None => unreachable!("(bug) use of unresolved type declaration [oid]"), + } + } + + pub(crate) fn try_oid(&self) -> Option { + Some(match self { + PgType::Bool => Oid(16), + PgType::Bytea => Oid(17), + PgType::Char => Oid(18), + PgType::Name => Oid(19), + PgType::Int8 => Oid(20), + PgType::Int2 => Oid(21), + PgType::Int4 => Oid(23), + PgType::Text => Oid(25), + PgType::Oid => Oid(26), + PgType::Json => Oid(114), + PgType::JsonArray => Oid(199), + PgType::Point => Oid(600), + PgType::Lseg => Oid(601), + PgType::Path => Oid(602), + PgType::Box => Oid(603), + PgType::Polygon => Oid(604), + PgType::Line => Oid(628), + PgType::LineArray => Oid(629), + PgType::Cidr => Oid(650), + PgType::CidrArray => Oid(651), + PgType::Float4 => Oid(700), + PgType::Float8 => Oid(701), + PgType::Unknown => Oid(705), + PgType::Circle => Oid(718), + PgType::CircleArray => Oid(719), + PgType::Macaddr8 => Oid(774), + PgType::Macaddr8Array => Oid(775), + PgType::Money => Oid(790), + PgType::MoneyArray => Oid(791), + PgType::Macaddr => Oid(829), + PgType::Inet => Oid(869), + PgType::BoolArray => Oid(1000), + PgType::ByteaArray => Oid(1001), + PgType::CharArray => Oid(1002), + PgType::NameArray => Oid(1003), + PgType::Int2Array => Oid(1005), + PgType::Int4Array => Oid(1007), + PgType::TextArray => Oid(1009), + PgType::BpcharArray => Oid(1014), + PgType::VarcharArray => Oid(1015), + PgType::Int8Array => Oid(1016), + PgType::PointArray => Oid(1017), + PgType::LsegArray => Oid(1018), + PgType::PathArray => Oid(1019), + PgType::BoxArray => Oid(1020), + PgType::Float4Array => Oid(1021), + PgType::Float8Array => Oid(1022), + PgType::PolygonArray => Oid(1027), + PgType::OidArray => Oid(1028), + PgType::MacaddrArray => Oid(1040), + PgType::InetArray => Oid(1041), + PgType::Bpchar => Oid(1042), + PgType::Varchar => Oid(1043), + PgType::Date => Oid(1082), + PgType::Time => Oid(1083), + PgType::Timestamp => Oid(1114), + PgType::TimestampArray => Oid(1115), + PgType::DateArray => Oid(1182), + PgType::TimeArray => Oid(1183), + PgType::Timestamptz => Oid(1184), + PgType::TimestamptzArray => Oid(1185), + PgType::Interval => Oid(1186), + PgType::IntervalArray => Oid(1187), + PgType::NumericArray => Oid(1231), + PgType::Timetz => Oid(1266), + PgType::TimetzArray => Oid(1270), + PgType::Bit => Oid(1560), + PgType::BitArray => Oid(1561), + PgType::Varbit => Oid(1562), + PgType::VarbitArray => Oid(1563), + PgType::Numeric => Oid(1700), + PgType::Void => Oid(2278), + PgType::Record => Oid(2249), + PgType::RecordArray => Oid(2287), + PgType::Uuid => Oid(2950), + PgType::UuidArray => Oid(2951), + PgType::Jsonb => Oid(3802), + PgType::JsonbArray => Oid(3807), + PgType::Int4Range => Oid(3904), + PgType::Int4RangeArray => Oid(3905), + PgType::NumRange => Oid(3906), + PgType::NumRangeArray => Oid(3907), + PgType::TsRange => Oid(3908), + PgType::TsRangeArray => Oid(3909), + PgType::TstzRange => 
Oid(3910), + PgType::TstzRangeArray => Oid(3911), + PgType::DateRange => Oid(3912), + PgType::DateRangeArray => Oid(3913), + PgType::Int8Range => Oid(3926), + PgType::Int8RangeArray => Oid(3927), + PgType::Jsonpath => Oid(4072), + PgType::JsonpathArray => Oid(4073), + + PgType::Custom(ty) => ty.oid, + + PgType::DeclareWithOid(oid) => *oid, + PgType::DeclareWithName(_) => { + return None; + } + PgType::DeclareArrayOf(_) => { + return None; + } + }) + } + + pub(crate) fn display_name(&self) -> &str { + match self { + PgType::Bool => "BOOL", + PgType::Bytea => "BYTEA", + PgType::Char => "\"CHAR\"", + PgType::Name => "NAME", + PgType::Int8 => "INT8", + PgType::Int2 => "INT2", + PgType::Int4 => "INT4", + PgType::Text => "TEXT", + PgType::Oid => "OID", + PgType::Json => "JSON", + PgType::JsonArray => "JSON[]", + PgType::Point => "POINT", + PgType::Lseg => "LSEG", + PgType::Path => "PATH", + PgType::Box => "BOX", + PgType::Polygon => "POLYGON", + PgType::Line => "LINE", + PgType::LineArray => "LINE[]", + PgType::Cidr => "CIDR", + PgType::CidrArray => "CIDR[]", + PgType::Float4 => "FLOAT4", + PgType::Float8 => "FLOAT8", + PgType::Unknown => "UNKNOWN", + PgType::Circle => "CIRCLE", + PgType::CircleArray => "CIRCLE[]", + PgType::Macaddr8 => "MACADDR8", + PgType::Macaddr8Array => "MACADDR8[]", + PgType::Macaddr => "MACADDR", + PgType::Inet => "INET", + PgType::BoolArray => "BOOL[]", + PgType::ByteaArray => "BYTEA[]", + PgType::CharArray => "\"CHAR\"[]", + PgType::NameArray => "NAME[]", + PgType::Int2Array => "INT2[]", + PgType::Int4Array => "INT4[]", + PgType::TextArray => "TEXT[]", + PgType::BpcharArray => "CHAR[]", + PgType::VarcharArray => "VARCHAR[]", + PgType::Int8Array => "INT8[]", + PgType::PointArray => "POINT[]", + PgType::LsegArray => "LSEG[]", + PgType::PathArray => "PATH[]", + PgType::BoxArray => "BOX[]", + PgType::Float4Array => "FLOAT4[]", + PgType::Float8Array => "FLOAT8[]", + PgType::PolygonArray => "POLYGON[]", + PgType::OidArray => "OID[]", + PgType::MacaddrArray => "MACADDR[]", + PgType::InetArray => "INET[]", + PgType::Bpchar => "CHAR", + PgType::Varchar => "VARCHAR", + PgType::Date => "DATE", + PgType::Time => "TIME", + PgType::Timestamp => "TIMESTAMP", + PgType::TimestampArray => "TIMESTAMP[]", + PgType::DateArray => "DATE[]", + PgType::TimeArray => "TIME[]", + PgType::Timestamptz => "TIMESTAMPTZ", + PgType::TimestamptzArray => "TIMESTAMPTZ[]", + PgType::Interval => "INTERVAL", + PgType::IntervalArray => "INTERVAL[]", + PgType::NumericArray => "NUMERIC[]", + PgType::Timetz => "TIMETZ", + PgType::TimetzArray => "TIMETZ[]", + PgType::Bit => "BIT", + PgType::BitArray => "BIT[]", + PgType::Varbit => "VARBIT", + PgType::VarbitArray => "VARBIT[]", + PgType::Numeric => "NUMERIC", + PgType::Record => "RECORD", + PgType::RecordArray => "RECORD[]", + PgType::Uuid => "UUID", + PgType::UuidArray => "UUID[]", + PgType::Jsonb => "JSONB", + PgType::JsonbArray => "JSONB[]", + PgType::Int4Range => "INT4RANGE", + PgType::Int4RangeArray => "INT4RANGE[]", + PgType::NumRange => "NUMRANGE", + PgType::NumRangeArray => "NUMRANGE[]", + PgType::TsRange => "TSRANGE", + PgType::TsRangeArray => "TSRANGE[]", + PgType::TstzRange => "TSTZRANGE", + PgType::TstzRangeArray => "TSTZRANGE[]", + PgType::DateRange => "DATERANGE", + PgType::DateRangeArray => "DATERANGE[]", + PgType::Int8Range => "INT8RANGE", + PgType::Int8RangeArray => "INT8RANGE[]", + PgType::Jsonpath => "JSONPATH", + PgType::JsonpathArray => "JSONPATH[]", + PgType::Money => "MONEY", + PgType::MoneyArray => "MONEY[]", + PgType::Void => "VOID", + 
PgType::Custom(ty) => &ty.name, + PgType::DeclareWithOid(_) => "?", + PgType::DeclareWithName(name) => name, + PgType::DeclareArrayOf(array) => &array.name, + } + } + + pub(crate) fn name(&self) -> &str { + match self { + PgType::Bool => "bool", + PgType::Bytea => "bytea", + PgType::Char => "char", + PgType::Name => "name", + PgType::Int8 => "int8", + PgType::Int2 => "int2", + PgType::Int4 => "int4", + PgType::Text => "text", + PgType::Oid => "oid", + PgType::Json => "json", + PgType::JsonArray => "_json", + PgType::Point => "point", + PgType::Lseg => "lseg", + PgType::Path => "path", + PgType::Box => "box", + PgType::Polygon => "polygon", + PgType::Line => "line", + PgType::LineArray => "_line", + PgType::Cidr => "cidr", + PgType::CidrArray => "_cidr", + PgType::Float4 => "float4", + PgType::Float8 => "float8", + PgType::Unknown => "unknown", + PgType::Circle => "circle", + PgType::CircleArray => "_circle", + PgType::Macaddr8 => "macaddr8", + PgType::Macaddr8Array => "_macaddr8", + PgType::Macaddr => "macaddr", + PgType::Inet => "inet", + PgType::BoolArray => "_bool", + PgType::ByteaArray => "_bytea", + PgType::CharArray => "_char", + PgType::NameArray => "_name", + PgType::Int2Array => "_int2", + PgType::Int4Array => "_int4", + PgType::TextArray => "_text", + PgType::BpcharArray => "_bpchar", + PgType::VarcharArray => "_varchar", + PgType::Int8Array => "_int8", + PgType::PointArray => "_point", + PgType::LsegArray => "_lseg", + PgType::PathArray => "_path", + PgType::BoxArray => "_box", + PgType::Float4Array => "_float4", + PgType::Float8Array => "_float8", + PgType::PolygonArray => "_polygon", + PgType::OidArray => "_oid", + PgType::MacaddrArray => "_macaddr", + PgType::InetArray => "_inet", + PgType::Bpchar => "bpchar", + PgType::Varchar => "varchar", + PgType::Date => "date", + PgType::Time => "time", + PgType::Timestamp => "timestamp", + PgType::TimestampArray => "_timestamp", + PgType::DateArray => "_date", + PgType::TimeArray => "_time", + PgType::Timestamptz => "timestamptz", + PgType::TimestamptzArray => "_timestamptz", + PgType::Interval => "interval", + PgType::IntervalArray => "_interval", + PgType::NumericArray => "_numeric", + PgType::Timetz => "timetz", + PgType::TimetzArray => "_timetz", + PgType::Bit => "bit", + PgType::BitArray => "_bit", + PgType::Varbit => "varbit", + PgType::VarbitArray => "_varbit", + PgType::Numeric => "numeric", + PgType::Record => "record", + PgType::RecordArray => "_record", + PgType::Uuid => "uuid", + PgType::UuidArray => "_uuid", + PgType::Jsonb => "jsonb", + PgType::JsonbArray => "_jsonb", + PgType::Int4Range => "int4range", + PgType::Int4RangeArray => "_int4range", + PgType::NumRange => "numrange", + PgType::NumRangeArray => "_numrange", + PgType::TsRange => "tsrange", + PgType::TsRangeArray => "_tsrange", + PgType::TstzRange => "tstzrange", + PgType::TstzRangeArray => "_tstzrange", + PgType::DateRange => "daterange", + PgType::DateRangeArray => "_daterange", + PgType::Int8Range => "int8range", + PgType::Int8RangeArray => "_int8range", + PgType::Jsonpath => "jsonpath", + PgType::JsonpathArray => "_jsonpath", + PgType::Money => "money", + PgType::MoneyArray => "_money", + PgType::Void => "void", + PgType::Custom(ty) => &ty.name, + PgType::DeclareWithOid(_) => "?", + PgType::DeclareWithName(name) => name, + PgType::DeclareArrayOf(array) => &array.name, + } + } + + pub(crate) fn kind(&self) -> &PgTypeKind { + match self { + PgType::Bool => &PgTypeKind::Simple, + PgType::Bytea => &PgTypeKind::Simple, + PgType::Char => &PgTypeKind::Simple, + 
PgType::Name => &PgTypeKind::Simple, + PgType::Int8 => &PgTypeKind::Simple, + PgType::Int2 => &PgTypeKind::Simple, + PgType::Int4 => &PgTypeKind::Simple, + PgType::Text => &PgTypeKind::Simple, + PgType::Oid => &PgTypeKind::Simple, + PgType::Json => &PgTypeKind::Simple, + PgType::JsonArray => &PgTypeKind::Array(PgTypeInfo(PgType::Json)), + PgType::Point => &PgTypeKind::Simple, + PgType::Lseg => &PgTypeKind::Simple, + PgType::Path => &PgTypeKind::Simple, + PgType::Box => &PgTypeKind::Simple, + PgType::Polygon => &PgTypeKind::Simple, + PgType::Line => &PgTypeKind::Simple, + PgType::LineArray => &PgTypeKind::Array(PgTypeInfo(PgType::Line)), + PgType::Cidr => &PgTypeKind::Simple, + PgType::CidrArray => &PgTypeKind::Array(PgTypeInfo(PgType::Cidr)), + PgType::Float4 => &PgTypeKind::Simple, + PgType::Float8 => &PgTypeKind::Simple, + PgType::Unknown => &PgTypeKind::Simple, + PgType::Circle => &PgTypeKind::Simple, + PgType::CircleArray => &PgTypeKind::Array(PgTypeInfo(PgType::Circle)), + PgType::Macaddr8 => &PgTypeKind::Simple, + PgType::Macaddr8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Macaddr8)), + PgType::Macaddr => &PgTypeKind::Simple, + PgType::Inet => &PgTypeKind::Simple, + PgType::BoolArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bool)), + PgType::ByteaArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bytea)), + PgType::CharArray => &PgTypeKind::Array(PgTypeInfo(PgType::Char)), + PgType::NameArray => &PgTypeKind::Array(PgTypeInfo(PgType::Name)), + PgType::Int2Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int2)), + PgType::Int4Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int4)), + PgType::TextArray => &PgTypeKind::Array(PgTypeInfo(PgType::Text)), + PgType::BpcharArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bpchar)), + PgType::VarcharArray => &PgTypeKind::Array(PgTypeInfo(PgType::Varchar)), + PgType::Int8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int8)), + PgType::PointArray => &PgTypeKind::Array(PgTypeInfo(PgType::Point)), + PgType::LsegArray => &PgTypeKind::Array(PgTypeInfo(PgType::Lseg)), + PgType::PathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Path)), + PgType::BoxArray => &PgTypeKind::Array(PgTypeInfo(PgType::Box)), + PgType::Float4Array => &PgTypeKind::Array(PgTypeInfo(PgType::Float4)), + PgType::Float8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Float8)), + PgType::PolygonArray => &PgTypeKind::Array(PgTypeInfo(PgType::Polygon)), + PgType::OidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Oid)), + PgType::MacaddrArray => &PgTypeKind::Array(PgTypeInfo(PgType::Macaddr)), + PgType::InetArray => &PgTypeKind::Array(PgTypeInfo(PgType::Inet)), + PgType::Bpchar => &PgTypeKind::Simple, + PgType::Varchar => &PgTypeKind::Simple, + PgType::Date => &PgTypeKind::Simple, + PgType::Time => &PgTypeKind::Simple, + PgType::Timestamp => &PgTypeKind::Simple, + PgType::TimestampArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timestamp)), + PgType::DateArray => &PgTypeKind::Array(PgTypeInfo(PgType::Date)), + PgType::TimeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Time)), + PgType::Timestamptz => &PgTypeKind::Simple, + PgType::TimestamptzArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timestamptz)), + PgType::Interval => &PgTypeKind::Simple, + PgType::IntervalArray => &PgTypeKind::Array(PgTypeInfo(PgType::Interval)), + PgType::NumericArray => &PgTypeKind::Array(PgTypeInfo(PgType::Numeric)), + PgType::Timetz => &PgTypeKind::Simple, + PgType::TimetzArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timetz)), + PgType::Bit => &PgTypeKind::Simple, + PgType::BitArray => 
&PgTypeKind::Array(PgTypeInfo(PgType::Bit)),
+            PgType::Varbit => &PgTypeKind::Simple,
+            PgType::VarbitArray => &PgTypeKind::Array(PgTypeInfo(PgType::Varbit)),
+            PgType::Numeric => &PgTypeKind::Simple,
+            PgType::Record => &PgTypeKind::Simple,
+            PgType::RecordArray => &PgTypeKind::Array(PgTypeInfo(PgType::Record)),
+            PgType::Uuid => &PgTypeKind::Simple,
+            PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)),
+            PgType::Jsonb => &PgTypeKind::Simple,
+            PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)),
+            PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4),
+            PgType::Int4RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int4Range)),
+            PgType::NumRange => &PgTypeKind::Range(PgTypeInfo::NUMERIC),
+            PgType::NumRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::NumRange)),
+            PgType::TsRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMP),
+            PgType::TsRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsRange)),
+            PgType::TstzRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMPTZ),
+            PgType::TstzRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TstzRange)),
+            PgType::DateRange => &PgTypeKind::Range(PgTypeInfo::DATE),
+            PgType::DateRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::DateRange)),
+            PgType::Int8Range => &PgTypeKind::Range(PgTypeInfo::INT8),
+            PgType::Int8RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int8Range)),
+            PgType::Jsonpath => &PgTypeKind::Simple,
+            PgType::JsonpathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonpath)),
+            PgType::Money => &PgTypeKind::Simple,
+            PgType::MoneyArray => &PgTypeKind::Array(PgTypeInfo(PgType::Money)),
+
+            PgType::Void => &PgTypeKind::Pseudo,
+
+            PgType::Custom(ty) => &ty.kind,
+
+            PgType::DeclareWithOid(oid) => {
+                unreachable!("(bug) use of unresolved type declaration [oid={}]", oid.0);
+            }
+            PgType::DeclareWithName(name) => {
+                unreachable!("(bug) use of unresolved type declaration [name={name}]");
+            }
+            PgType::DeclareArrayOf(array) => {
+                unreachable!(
+                    "(bug) use of unresolved type declaration [array of={}]",
+                    array.elem_name
+                );
+            }
+        }
+    }
+
+    /// If `self` is an array type, return the type info for its element.
+    pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
+        // We explicitly match on all the `None` cases to ensure an exhaustive match.
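+        // For example, `Int4Array` maps back to `Int4` below, and a legacy
+        // declaration by the name `_int4` maps back to a declaration named `int4`.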
+ match self { + PgType::Bool => None, + PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))), + PgType::Bytea => None, + PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))), + PgType::Char => None, + PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))), + PgType::Name => None, + PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))), + PgType::Int8 => None, + PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))), + PgType::Int2 => None, + PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))), + PgType::Int4 => None, + PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))), + PgType::Text => None, + PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))), + PgType::Oid => None, + PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))), + PgType::Json => None, + PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))), + PgType::Point => None, + PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))), + PgType::Lseg => None, + PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))), + PgType::Path => None, + PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))), + PgType::Box => None, + PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))), + PgType::Polygon => None, + PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))), + PgType::Line => None, + PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))), + PgType::Cidr => None, + PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))), + PgType::Float4 => None, + PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))), + PgType::Float8 => None, + PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))), + PgType::Circle => None, + PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))), + PgType::Macaddr8 => None, + PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))), + PgType::Money => None, + PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))), + PgType::Macaddr => None, + PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))), + PgType::Inet => None, + PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))), + PgType::Bpchar => None, + PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))), + PgType::Varchar => None, + PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))), + PgType::Date => None, + PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))), + PgType::Time => None, + PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))), + PgType::Timestamp => None, + PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))), + PgType::Timestamptz => None, + PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))), + PgType::Interval => None, + PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))), + PgType::Timetz => None, + PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))), + PgType::Bit => None, + PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))), + PgType::Varbit => None, + PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))), + PgType::Numeric => None, + PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))), + PgType::Record => None, + PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))), + PgType::Uuid => None, + PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))), + PgType::Jsonb => None, + 
PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))), + PgType::Int4Range => None, + PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))), + PgType::NumRange => None, + PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))), + PgType::TsRange => None, + PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))), + PgType::TstzRange => None, + PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))), + PgType::DateRange => None, + PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))), + PgType::Int8Range => None, + PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))), + PgType::Jsonpath => None, + PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))), + // There is no `UnknownArray` + PgType::Unknown => None, + // There is no `VoidArray` + PgType::Void => None, + + PgType::Custom(ty) => match &ty.kind { + PgTypeKind::Simple => None, + PgTypeKind::Pseudo => None, + PgTypeKind::Domain(_) => None, + PgTypeKind::Composite(_) => None, + PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)), + PgTypeKind::Enum(_) => None, + PgTypeKind::Range(_) => None, + }, + PgType::DeclareWithOid(_) => None, + PgType::DeclareWithName(name) => { + // LEGACY: infer the array element name from a `_` prefix + UStr::strip_prefix(name, "_") + .map(|elem| Cow::Owned(PgTypeInfo(PgType::DeclareWithName(elem)))) + } + PgType::DeclareArrayOf(array) => Some(Cow::Owned(PgTypeInfo(PgType::DeclareWithName( + array.elem_name.clone(), + )))), + } + } + + /// Returns `true` if this type cannot be matched by name. + fn is_declare_with_oid(&self) -> bool { + matches!(self, Self::DeclareWithOid(_)) + } + + /// Compare two `PgType`s, first by OID, then by array element, then by name. + /// + /// If `soft_eq` is true and `self` or `other` is `DeclareWithOid` but not both, return `true` + /// before checking names. + fn eq_impl(&self, other: &Self, soft_eq: bool) -> bool { + if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) { + // If there are OIDs available, use OIDs to perform a direct match + return a == b; + } + + if soft_eq && (self.is_declare_with_oid() || other.is_declare_with_oid()) { + // If we get to this point, one instance is `DeclareWithOid()` and the other is + // `DeclareArrayOf()` or `DeclareWithName()`, which means we can't compare the two. + // + // Since this is only likely to occur when using the text protocol where we can't + // resolve type names before executing a query, we can just opt out of typechecking. 
+ return true; + } + + if let (Some(elem_a), Some(elem_b)) = (self.try_array_element(), other.try_array_element()) + { + return elem_a == elem_b; + } + + // Otherwise, perform a match on the name + name_eq(self.name(), other.name()) + } +} + +impl TypeInfo for PgTypeInfo { + fn name(&self) -> &str { + self.0.display_name() + } + + fn is_null(&self) -> bool { + false + } + + fn is_void(&self) -> bool { + matches!(self.0, PgType::Void) + } + + fn type_compatible(&self, other: &Self) -> bool + where + Self: Sized, + { + self == other + } +} + +impl PartialEq for PgCustomType { + fn eq(&self, other: &PgCustomType) -> bool { + other.oid == self.oid + } +} + +impl PgTypeInfo { + // boolean, state of true or false + pub(crate) const BOOL: Self = Self(PgType::Bool); + pub(crate) const BOOL_ARRAY: Self = Self(PgType::BoolArray); + + // binary data types, variable-length binary string + pub(crate) const BYTEA: Self = Self(PgType::Bytea); + pub(crate) const BYTEA_ARRAY: Self = Self(PgType::ByteaArray); + + // uuid + pub(crate) const UUID: Self = Self(PgType::Uuid); + pub(crate) const UUID_ARRAY: Self = Self(PgType::UuidArray); + + // record + pub(crate) const RECORD: Self = Self(PgType::Record); + pub(crate) const RECORD_ARRAY: Self = Self(PgType::RecordArray); + + // + // JSON types + // https://www.postgresql.org/docs/current/datatype-json.html + // + + pub(crate) const JSON: Self = Self(PgType::Json); + pub(crate) const JSON_ARRAY: Self = Self(PgType::JsonArray); + + pub(crate) const JSONB: Self = Self(PgType::Jsonb); + pub(crate) const JSONB_ARRAY: Self = Self(PgType::JsonbArray); + + pub(crate) const JSONPATH: Self = Self(PgType::Jsonpath); + pub(crate) const JSONPATH_ARRAY: Self = Self(PgType::JsonpathArray); + + // + // network address types + // https://www.postgresql.org/docs/current/datatype-net-types.html + // + + pub(crate) const CIDR: Self = Self(PgType::Cidr); + pub(crate) const CIDR_ARRAY: Self = Self(PgType::CidrArray); + + pub(crate) const INET: Self = Self(PgType::Inet); + pub(crate) const INET_ARRAY: Self = Self(PgType::InetArray); + + pub(crate) const MACADDR: Self = Self(PgType::Macaddr); + pub(crate) const MACADDR_ARRAY: Self = Self(PgType::MacaddrArray); + + pub(crate) const MACADDR8: Self = Self(PgType::Macaddr8); + pub(crate) const MACADDR8_ARRAY: Self = Self(PgType::Macaddr8Array); + + // + // character types + // https://www.postgresql.org/docs/current/datatype-character.html + // + + // internal type for object names + pub(crate) const NAME: Self = Self(PgType::Name); + pub(crate) const NAME_ARRAY: Self = Self(PgType::NameArray); + + // character type, fixed-length, blank-padded + pub(crate) const BPCHAR: Self = Self(PgType::Bpchar); + pub(crate) const BPCHAR_ARRAY: Self = Self(PgType::BpcharArray); + + // character type, variable-length with limit + pub(crate) const VARCHAR: Self = Self(PgType::Varchar); + pub(crate) const VARCHAR_ARRAY: Self = Self(PgType::VarcharArray); + + // character type, variable-length + pub(crate) const TEXT: Self = Self(PgType::Text); + pub(crate) const TEXT_ARRAY: Self = Self(PgType::TextArray); + + // unknown type, transmitted as text + pub(crate) const UNKNOWN: Self = Self(PgType::Unknown); + + // + // numeric types + // https://www.postgresql.org/docs/current/datatype-numeric.html + // + + // single-byte internal type + pub(crate) const CHAR: Self = Self(PgType::Char); + pub(crate) const CHAR_ARRAY: Self = Self(PgType::CharArray); + + // internal type for type ids + pub(crate) const OID: Self = Self(PgType::Oid); + pub(crate) const 
OID_ARRAY: Self = Self(PgType::OidArray); + + // small-range integer; -32768 to +32767 + pub(crate) const INT2: Self = Self(PgType::Int2); + pub(crate) const INT2_ARRAY: Self = Self(PgType::Int2Array); + + // typical choice for integer; -2147483648 to +2147483647 + pub(crate) const INT4: Self = Self(PgType::Int4); + pub(crate) const INT4_ARRAY: Self = Self(PgType::Int4Array); + + // large-range integer; -9223372036854775808 to +9223372036854775807 + pub(crate) const INT8: Self = Self(PgType::Int8); + pub(crate) const INT8_ARRAY: Self = Self(PgType::Int8Array); + + // variable-precision, inexact, 6 decimal digits precision + pub(crate) const FLOAT4: Self = Self(PgType::Float4); + pub(crate) const FLOAT4_ARRAY: Self = Self(PgType::Float4Array); + + // variable-precision, inexact, 15 decimal digits precision + pub(crate) const FLOAT8: Self = Self(PgType::Float8); + pub(crate) const FLOAT8_ARRAY: Self = Self(PgType::Float8Array); + + // user-specified precision, exact + pub(crate) const NUMERIC: Self = Self(PgType::Numeric); + pub(crate) const NUMERIC_ARRAY: Self = Self(PgType::NumericArray); + + // user-specified precision, exact + pub(crate) const MONEY: Self = Self(PgType::Money); + pub(crate) const MONEY_ARRAY: Self = Self(PgType::MoneyArray); + + // + // date/time types + // https://www.postgresql.org/docs/current/datatype-datetime.html + // + + // both date and time (no time zone) + pub(crate) const TIMESTAMP: Self = Self(PgType::Timestamp); + pub(crate) const TIMESTAMP_ARRAY: Self = Self(PgType::TimestampArray); + + // both date and time (with time zone) + pub(crate) const TIMESTAMPTZ: Self = Self(PgType::Timestamptz); + pub(crate) const TIMESTAMPTZ_ARRAY: Self = Self(PgType::TimestamptzArray); + + // date (no time of day) + pub(crate) const DATE: Self = Self(PgType::Date); + pub(crate) const DATE_ARRAY: Self = Self(PgType::DateArray); + + // time of day (no date) + pub(crate) const TIME: Self = Self(PgType::Time); + pub(crate) const TIME_ARRAY: Self = Self(PgType::TimeArray); + + // time of day (no date), with time zone + pub(crate) const TIMETZ: Self = Self(PgType::Timetz); + pub(crate) const TIMETZ_ARRAY: Self = Self(PgType::TimetzArray); + + // time interval + pub(crate) const INTERVAL: Self = Self(PgType::Interval); + pub(crate) const INTERVAL_ARRAY: Self = Self(PgType::IntervalArray); + + // + // geometric types + // https://www.postgresql.org/docs/current/datatype-geometric.html + // + + // point on a plane + pub(crate) const POINT: Self = Self(PgType::Point); + pub(crate) const POINT_ARRAY: Self = Self(PgType::PointArray); + + // infinite line + pub(crate) const LINE: Self = Self(PgType::Line); + pub(crate) const LINE_ARRAY: Self = Self(PgType::LineArray); + + // finite line segment + pub(crate) const LSEG: Self = Self(PgType::Lseg); + pub(crate) const LSEG_ARRAY: Self = Self(PgType::LsegArray); + + // rectangular box + pub(crate) const BOX: Self = Self(PgType::Box); + pub(crate) const BOX_ARRAY: Self = Self(PgType::BoxArray); + + // open or closed path + pub(crate) const PATH: Self = Self(PgType::Path); + pub(crate) const PATH_ARRAY: Self = Self(PgType::PathArray); + + // polygon + pub(crate) const POLYGON: Self = Self(PgType::Polygon); + pub(crate) const POLYGON_ARRAY: Self = Self(PgType::PolygonArray); + + // circle + pub(crate) const CIRCLE: Self = Self(PgType::Circle); + pub(crate) const CIRCLE_ARRAY: Self = Self(PgType::CircleArray); + + // + // bit string types + // https://www.postgresql.org/docs/current/datatype-bit.html + // + + pub(crate) const BIT: Self = 
Self(PgType::Bit); + pub(crate) const BIT_ARRAY: Self = Self(PgType::BitArray); + + pub(crate) const VARBIT: Self = Self(PgType::Varbit); + pub(crate) const VARBIT_ARRAY: Self = Self(PgType::VarbitArray); + + // + // range types + // https://www.postgresql.org/docs/current/rangetypes.html + // + + pub(crate) const INT4_RANGE: Self = Self(PgType::Int4Range); + pub(crate) const INT4_RANGE_ARRAY: Self = Self(PgType::Int4RangeArray); + + pub(crate) const NUM_RANGE: Self = Self(PgType::NumRange); + pub(crate) const NUM_RANGE_ARRAY: Self = Self(PgType::NumRangeArray); + + pub(crate) const TS_RANGE: Self = Self(PgType::TsRange); + pub(crate) const TS_RANGE_ARRAY: Self = Self(PgType::TsRangeArray); + + pub(crate) const TSTZ_RANGE: Self = Self(PgType::TstzRange); + pub(crate) const TSTZ_RANGE_ARRAY: Self = Self(PgType::TstzRangeArray); + + pub(crate) const DATE_RANGE: Self = Self(PgType::DateRange); + pub(crate) const DATE_RANGE_ARRAY: Self = Self(PgType::DateRangeArray); + + pub(crate) const INT8_RANGE: Self = Self(PgType::Int8Range); + pub(crate) const INT8_RANGE_ARRAY: Self = Self(PgType::Int8RangeArray); + + // + // pseudo types + // https://www.postgresql.org/docs/9.3/datatype-pseudo.html + // + + pub(crate) const VOID: Self = Self(PgType::Void); +} + +impl Display for PgTypeInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.pad(self.name()) + } +} + +impl PartialEq for PgType { + fn eq(&self, other: &PgType) -> bool { + self.eq_impl(other, true) + } +} + +/// Check type names for equality, respecting Postgres' case sensitivity rules for identifiers. +/// +/// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +fn name_eq(name1: &str, name2: &str) -> bool { + // Cop-out of processing Unicode escapes by just using string equality. + if name1.starts_with("U&") { + // If `name2` doesn't start with `U&` this will automatically be `false`. + return name1 == name2; + } + + let mut chars1 = identifier_chars(name1); + let mut chars2 = identifier_chars(name2); + + while let (Some(a), Some(b)) = (chars1.next(), chars2.next()) { + if !a.eq(&b) { + return false; + } + } + + chars1.next().is_none() && chars2.next().is_none() +} + +struct IdentifierChar { + ch: char, + case_sensitive: bool, +} + +impl IdentifierChar { + fn eq(&self, other: &Self) -> bool { + if self.case_sensitive || other.case_sensitive { + self.ch == other.ch + } else { + self.ch.eq_ignore_ascii_case(&other.ch) + } + } +} + +/// Return an iterator over all significant characters of an identifier. +/// +/// Ignores non-escaped quotation marks. 
+fn identifier_chars(ident: &str) -> impl Iterator<Item = IdentifierChar> + '_ {
+    let mut case_sensitive = false;
+    let mut last_char_quote = false;
+
+    ident.chars().filter_map(move |ch| {
+        if ch == '"' {
+            if last_char_quote {
+                last_char_quote = false;
+            } else {
+                last_char_quote = true;
+                return None;
+            }
+        } else if last_char_quote {
+            last_char_quote = false;
+            case_sensitive = !case_sensitive;
+        }
+
+        Some(IdentifierChar { ch, case_sensitive })
+    })
+}
+
+#[test]
+fn test_name_eq() {
+    let test_values = [
+        ("foo", "foo", true),
+        ("foo", "Foo", true),
+        ("foo", "FOO", true),
+        ("foo", r#""foo""#, true),
+        ("foo", r#""Foo""#, false),
+        ("foo", "foo.foo", false),
+        ("foo.foo", "foo.foo", true),
+        ("foo.foo", "foo.Foo", true),
+        ("foo.foo", "foo.FOO", true),
+        ("foo.foo", "Foo.foo", true),
+        ("foo.foo", "Foo.Foo", true),
+        ("foo.foo", "FOO.FOO", true),
+        ("foo.foo", "foo", false),
+        ("foo.foo", r#"foo."foo""#, true),
+        ("foo.foo", r#"foo."Foo""#, false),
+        ("foo.foo", r#"foo."FOO""#, false),
+    ];
+
+    for (left, right, eq) in test_values {
+        assert_eq!(
+            name_eq(left, right),
+            eq,
+            "failed check for name_eq({left:?}, {right:?})"
+        );
+        assert_eq!(
+            name_eq(right, left),
+            eq,
+            "failed check for name_eq({right:?}, {left:?})"
+        );
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/array.rs b/patches/sqlx-postgres/src/types/array.rs
new file mode 100644
index 000000000..9b8be6341
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/array.rs
@@ -0,0 +1,356 @@
+use sqlx_core::bytes::Buf;
+use sqlx_core::types::Text;
+use std::borrow::Cow;
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::type_info::PgType;
+use crate::types::Oid;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+/// Provides information necessary to encode and decode Postgres arrays as compatible Rust types.
+///
+/// Implementing this trait for some type `T` enables relevant `Type`, `Encode` and `Decode` impls
+/// for `Vec<T>`, `&[T]` (slices), `[T; N]` (arrays), etc.
+///
+/// ### Note: `#[derive(sqlx::Type)]`
+/// If you have the `postgres` feature enabled, `#[derive(sqlx::Type)]` will also generate
+/// an impl of this trait for your type if your wrapper is marked `#[sqlx(transparent)]`:
+///
+/// ```rust,ignore
+/// #[derive(sqlx::Type)]
+/// #[sqlx(transparent)]
+/// struct UserId(i64);
+///
+/// let user_ids: Vec<UserId> = sqlx::query_scalar("select '{ 123, 456 }'::int8[]")
+///     .fetch(&mut pg_connection)
+///     .await?;
+/// ```
+///
+/// However, this may cause an error if the type being wrapped does not implement `PgHasArrayType`,
+/// e.g. `Vec<i64>` itself, because we don't currently support multidimensional arrays:
+///
+/// ```rust,ignore
+/// #[derive(sqlx::Type)] // ERROR: `Vec<i64>` does not implement `PgHasArrayType`
+/// #[sqlx(transparent)]
+/// struct UserIds(Vec<i64>);
+/// ```
+///
+/// To remedy this, add `#[sqlx(no_pg_array)]`, which disables the generation
+/// of the `PgHasArrayType` impl:
+///
+/// ```rust,ignore
+/// #[derive(sqlx::Type)]
+/// #[sqlx(transparent, no_pg_array)]
+/// struct UserIds(Vec<i64>);
+/// ```
+///
+/// See [the documentation of `Type`][Type] for more details.
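+/// ### Note: Manual Implementations (editorial sketch)
+/// A hand-written impl can delegate to [`PgTypeInfo::array_of()`] with the
+/// element type's name (`citext` here is only an illustrative example):
+///
+/// ```rust,ignore
+/// use sqlx::postgres::{PgHasArrayType, PgTypeInfo};
+///
+/// struct Citext(String);
+///
+/// impl PgHasArrayType for Citext {
+///     fn array_type_info() -> PgTypeInfo {
+///         PgTypeInfo::array_of("citext")
+///     }
+/// }
+/// ```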
+pub trait PgHasArrayType {
+    fn array_type_info() -> PgTypeInfo;
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        *ty == Self::array_type_info()
+    }
+}
+
+impl<T> PgHasArrayType for &T
+where
+    T: PgHasArrayType,
+{
+    fn array_type_info() -> PgTypeInfo {
+        T::array_type_info()
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        T::array_compatible(ty)
+    }
+}
+
+impl<T> PgHasArrayType for Option<T>
+where
+    T: PgHasArrayType,
+{
+    fn array_type_info() -> PgTypeInfo {
+        T::array_type_info()
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        T::array_compatible(ty)
+    }
+}
+
+impl<T> PgHasArrayType for Text<T> {
+    fn array_type_info() -> PgTypeInfo {
+        String::array_type_info()
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        String::array_compatible(ty)
+    }
+}
+
+impl<T> Type<Postgres> for [T]
+where
+    T: PgHasArrayType,
+{
+    fn type_info() -> PgTypeInfo {
+        T::array_type_info()
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        T::array_compatible(ty)
+    }
+}
+
+impl<T> Type<Postgres> for Vec<T>
+where
+    T: PgHasArrayType,
+{
+    fn type_info() -> PgTypeInfo {
+        T::array_type_info()
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        T::array_compatible(ty)
+    }
+}
+
+impl<T, const N: usize> Type<Postgres> for [T; N]
+where
+    T: PgHasArrayType,
+{
+    fn type_info() -> PgTypeInfo {
+        T::array_type_info()
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        T::array_compatible(ty)
+    }
+}
+
+impl<'q, T> Encode<'q, Postgres> for Vec<T>
+where
+    for<'a> &'a [T]: Encode<'q, Postgres>,
+    T: Encode<'q, Postgres>,
+{
+    #[inline]
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        self.as_slice().encode_by_ref(buf)
+    }
+}
+
+impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N]
+where
+    for<'a> &'a [T]: Encode<'q, Postgres>,
+    T: Encode<'q, Postgres>,
+{
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        self.as_slice().encode_by_ref(buf)
+    }
+}
+
+impl<'q, T> Encode<'q, Postgres> for &'_ [T]
+where
+    T: Encode<'q, Postgres> + Type<Postgres>,
+{
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        let type_info = self
+            .first()
+            .and_then(Encode::produces)
+            .unwrap_or_else(T::type_info);
+
+        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
+        buf.extend(&0_i32.to_be_bytes()); // flags
+
+        // element type
+        match type_info.0 {
+            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
+            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
+
+            ty => {
+                buf.extend(&ty.oid().0.to_be_bytes());
+            }
+        }
+
+        let array_len = i32::try_from(self.len()).map_err(|_| {
+            format!(
+                "encoded array length is too large for Postgres: {}",
+                self.len()
+            )
+        })?;
+
+        buf.extend(array_len.to_be_bytes()); // len
+        buf.extend(&1_i32.to_be_bytes()); // lower bound
+
+        for element in self.iter() {
+            buf.encode(element)?;
+        }
+
+        Ok(IsNull::No)
+    }
+}
+
+impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
+where
+    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
+{
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        // This could be done more efficiently by refactoring the Vec decoding below so that it can
+        // be used for arrays and Vec.
+        let vec: Vec<T> = Decode::decode(value)?;
+        let array: [T; N] = vec.try_into().map_err(|_| "wrong number of elements")?;
+        Ok(array)
+    }
+}
+
+impl<'r, T> Decode<'r, Postgres> for Vec<T>
+where
+    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
+{
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let format = value.format();
+
+        match format {
+            PgValueFormat::Binary => {
+                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548
+
+                let mut buf = value.as_bytes()?;
+
+                // number of dimensions in the array
+                let ndim = buf.get_i32();
+
+                if ndim == 0 {
+                    // zero dimensions is an empty array
+                    return Ok(Vec::new());
+                }
+
+                if ndim != 1 {
+                    return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into());
+                }
+
+                // appears to have been used in the past to communicate potential NULLS
+                // but reading source code back through our supported postgres versions (9.5+)
+                // this is never used for anything
+                let _flags = buf.get_i32();
+
+                // the OID of the element
+                let element_type_oid = Oid(buf.get_u32());
+                let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
+                    .or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
+                    .ok_or_else(|| {
+                        BoxDynError::from(format!(
+                            "failed to resolve array element type for oid {}",
+                            element_type_oid.0
+                        ))
+                    })?;
+
+                // length of the array axis
+                let len = buf.get_i32();
+
+                let len = usize::try_from(len)
+                    .map_err(|_| format!("overflow converting array len ({len}) to usize"))?;
+
+                // the lower bound, we only support arrays starting from "1"
+                let lower = buf.get_i32();
+
+                if lower != 1 {
+                    return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
+                }
+
+                let mut elements = Vec::with_capacity(len);
+
+                for _ in 0..len {
+                    let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?;
+
+                    elements.push(T::decode(value_ref)?);
+                }
+
+                Ok(elements)
+            }
+
+            PgValueFormat::Text => {
+                // no type is provided from the database for the element
+                let element_type_info = T::type_info();
+
+                let s = value.as_str()?;
+
+                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718
+
+                // trim the wrapping braces
+                let s = &s[1..(s.len() - 1)];
+
+                if s.is_empty() {
+                    // short-circuit empty arrays up here
+                    return Ok(Vec::new());
+                }
+
+                // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
+                //       that does not. The BOX (not PostGIS) type uses ';' as a delimiter.
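+                // For example, the text form `{1,"two, three",NULL}` yields the
+                // elements `1`, `two, three` (quotes and escapes removed), and a
+                // SQL NULL, which only decodes successfully into an `Option<T>`.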
+
+                // TODO: When we add support for BOX we need to figure out some way to make the
+                //       delimiter selectable.
+
+                let delimiter = ',';
+                let mut done = false;
+                let mut in_quotes = false;
+                let mut in_escape = false;
+                let mut value = String::with_capacity(10);
+                let mut chars = s.chars();
+                let mut elements = Vec::with_capacity(4);
+
+                while !done {
+                    loop {
+                        match chars.next() {
+                            Some(ch) => match ch {
+                                _ if in_escape => {
+                                    value.push(ch);
+                                    in_escape = false;
+                                }
+
+                                '"' => {
+                                    in_quotes = !in_quotes;
+                                }
+
+                                '\\' => {
+                                    in_escape = true;
+                                }
+
+                                _ if ch == delimiter && !in_quotes => {
+                                    break;
+                                }
+
+                                _ => {
+                                    value.push(ch);
+                                }
+                            },
+
+                            None => {
+                                done = true;
+                                break;
+                            }
+                        }
+                    }
+
+                    let value_opt = if value == "NULL" {
+                        None
+                    } else {
+                        Some(value.as_bytes())
+                    };
+
+                    elements.push(T::decode(PgValueRef {
+                        value: value_opt,
+                        row: None,
+                        type_info: element_type_info.clone(),
+                        format,
+                    })?);
+
+                    value.clear();
+                }
+
+                Ok(elements)
+            }
+        }
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/bigdecimal-range.md b/patches/sqlx-postgres/src/types/bigdecimal-range.md
new file mode 100644
index 000000000..5d4ee502b
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/bigdecimal-range.md
@@ -0,0 +1,20 @@
+#### Note: `BigDecimal` Has a Larger Range than `NUMERIC`
+`BigDecimal` can represent values with a far, far greater range than the `NUMERIC` type in Postgres can.
+
+`NUMERIC` is limited to 131,072 digits before the decimal point, and 16,384 digits after it.
+See [Section 8.1, Numeric Types] of the Postgres manual for details.
+
+Meanwhile, `BigDecimal` can theoretically represent a value with an arbitrary number of decimal digits, albeit
+with a maximum of 2^63 significant figures.
+
+Because encoding in the current API design _must_ be infallible,
+when attempting to encode a `BigDecimal` that cannot fit in the wire representation of `NUMERIC`,
+SQLx may instead encode a sentinel value that falls outside the allowed range but is still representable.
+
+This will cause the query to return a `DatabaseError` with code `22P03` (`invalid_binary_representation`)
+and the error message `invalid scale in external "numeric" value` (though this may be subject to change).
+
+However, `BigDecimal` should be able to decode any `NUMERIC` value except `NaN`,
+for which it has no representation.
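+#### Editorial sketch: observing the scale limit
+A hedged example (assuming the `bigdecimal` crate's `new` and
+`as_bigint_and_exponent` APIs) of a value whose scale exceeds what `NUMERIC`
+can store; per the note above, binding it to a `NUMERIC` parameter would
+surface error `22P03` at execution time rather than at encode time:
+
+```rust
+use bigdecimal::BigDecimal;
+
+fn main() {
+    // 1 * 10^-100000: far beyond NUMERIC's 16,384 fractional digits.
+    let tiny = BigDecimal::new(1.into(), 100_000);
+    let (_, exponent) = tiny.as_bigint_and_exponent();
+    assert_eq!(exponent, 100_000);
+    println!("scale too large for NUMERIC: {exponent}");
+}
+```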
+
+[Section 8.1, Numeric Types]: https://www.postgresql.org/docs/current/datatype-numeric.html
diff --git a/patches/sqlx-postgres/src/types/bigdecimal.rs b/patches/sqlx-postgres/src/types/bigdecimal.rs
new file mode 100644
index 000000000..869f85079
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/bigdecimal.rs
@@ -0,0 +1,477 @@
+use bigdecimal::BigDecimal;
+use num_bigint::{BigInt, Sign};
+use std::cmp;
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::numeric::{PgNumeric, PgNumericSign};
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+impl Type<Postgres> for BigDecimal {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::NUMERIC
+    }
+}
+
+impl PgHasArrayType for BigDecimal {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::NUMERIC_ARRAY
+    }
+}
+
+impl TryFrom<PgNumeric> for BigDecimal {
+    type Error = BoxDynError;
+
+    fn try_from(numeric: PgNumeric) -> Result<Self, BoxDynError> {
+        Self::try_from(&numeric)
+    }
+}
+
+impl TryFrom<&'_ PgNumeric> for BigDecimal {
+    type Error = BoxDynError;
+
+    fn try_from(numeric: &'_ PgNumeric) -> Result<Self, BoxDynError> {
+        let (digits, sign, weight) = match *numeric {
+            PgNumeric::Number {
+                ref digits,
+                sign,
+                weight,
+                ..
+            } => (digits, sign, weight),
+
+            PgNumeric::NotANumber => {
+                return Err("BigDecimal does not support NaN values".into());
+            }
+        };
+
+        if digits.is_empty() {
+            // Postgres returns an empty digit array for 0 but BigInt expects at least one zero
+            return Ok(0u64.into());
+        }
+
+        let sign = match sign {
+            PgNumericSign::Positive => Sign::Plus,
+            PgNumericSign::Negative => Sign::Minus,
+        };
+
+        // weight is 0 if the decimal point falls after the first base-10000 digit
+        //
+        // `Vec` capacity cannot exceed `isize::MAX` bytes, so this cast can't wrap in practice.
+        #[allow(clippy::cast_possible_wrap)]
+        let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
+
+        // no optimized algorithm for base-10 so use base-100 for faster processing
+        let mut cents = Vec::with_capacity(digits.len() * 2);
+
+        #[allow(
+            clippy::cast_possible_truncation,
+            clippy::cast_possible_wrap,
+            clippy::cast_sign_loss
+        )]
+        for (i, &digit) in digits.iter().enumerate() {
+            if !PgNumeric::is_valid_digit(digit) {
+                return Err(format!(
+                    "PgNumeric to BigDecimal: {i}th digit is out of range {digit}"
+                )
+                .into());
+            }
+
+            cents.push((digit / 100) as u8);
+            cents.push((digit % 100) as u8);
+        }
+
+        let bigint = BigInt::from_radix_be(sign, &cents, 100)
+            .ok_or("PgNumeric contained an out-of-range digit")?;
+
+        Ok(BigDecimal::new(bigint, scale))
+    }
+}
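+
+// Illustrative sketch (not part of the upstream patch): the correspondence between
+// the base-10000 representation above and the decimal value. `digits: [1, 2345]`
+// with `weight: 0` places the decimal point after the first base-10000 digit, and
+// the computed `scale` counts base-10 fractional digits, so the value reads "1.2345".
+#[cfg(test)]
+mod pgnumeric_to_bigdecimal_sketch {
+    use super::*;
+
+    #[test]
+    fn digits_weight_and_scale_round_trip() {
+        let numeric = PgNumeric::Number {
+            sign: PgNumericSign::Positive,
+            scale: 4,
+            weight: 0,
+            digits: vec![1, 2345],
+        };
+
+        let decimal = BigDecimal::try_from(&numeric).unwrap();
+        assert_eq!(decimal, "1.2345".parse::<BigDecimal>().unwrap());
+    }
+}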
+
+impl TryFrom<&'_ BigDecimal> for PgNumeric {
+    type Error = BoxDynError;
+
+    fn try_from(decimal: &BigDecimal) -> Result<Self, BoxDynError> {
+        let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);
+
+        // NOTE: this unfortunately copies the BigInt internally
+        let (integer, exp) = decimal.as_bigint_and_exponent();
+
+        // this routine is specifically optimized for base-10
+        // FIXME: is there a way to iterate over the digits to avoid the Vec allocation
+        let (sign, base_10) = integer.to_radix_be(10);
+
+        let base_10_len = i64::try_from(base_10.len()).map_err(|_| {
+            format!(
+                "BigDecimal base-10 length out of range for PgNumeric: {}",
+                base_10.len()
+            )
+        })?;
+
+        // weight is positive power of 10000
+        // exp is the negative power of 10
+        let weight_10 = base_10_len - exp;
+
+        // scale is only nonzero when we have fractional digits
+        // since `exp` is the _negative_ decimal exponent, it tells us
+        // exactly what our scale should be
+        let scale: i16 = cmp::max(0, exp).try_into()?;
+
+        // there's an implicit +1 offset in the interpretation
+        let weight: i16 = if weight_10 <= 0 {
+            weight_10 / 4 - 1
+        } else {
+            // the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight)
+            (weight_10 - 1) / 4
+        }
+        .try_into()?;
+
+        let digits_len = if base_10.len() % 4 != 0 {
+            base_10.len() / 4 + 1
+        } else {
+            base_10.len() / 4
+        };
+
+        // For efficiency, we want to process the base-10 digits in chunks of 4,
+        // but that means we need to deal with the non-divisible remainder first.
+        let offset = weight_10.rem_euclid(4);
+
+        // Do a checked conversion to the smallest integer,
+        // so we can widen arbitrarily without triggering lints.
+        let offset = u8::try_from(offset).unwrap_or_else(|_| {
+            panic!("BUG: `offset` should be in the range [0, 4) but is {offset}")
+        });
+
+        let mut digits = Vec::with_capacity(digits_len);
+
+        if let Some(first) = base_10.get(..offset as usize) {
+            if !first.is_empty() {
+                digits.push(base_10_to_10000(first));
+            }
+        } else if offset != 0 {
+            // If we didn't hit the `if let Some` branch,
+            // then `base_10.len()` must strictly be smaller
+            #[allow(clippy::cast_possible_truncation)]
+            let power = (offset as usize - base_10.len()) as u32;
+
+            digits.push(base_10_to_10000(&base_10) * 10i16.pow(power));
+        }
+
+        if let Some(rest) = base_10.get(offset as usize..) {
+            // `chunk.len()` is always between 1 and 4
+            #[allow(clippy::cast_possible_truncation)]
+            digits.extend(
+                rest.chunks(4)
+                    .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)),
+            );
+        }
+
+        while let Some(&0) = digits.last() {
+            digits.pop();
+        }
+
+        Ok(PgNumeric::Number {
+            sign: sign_to_pg(sign),
+            scale,
+            weight,
+            digits,
+        })
+    }
+}
+
+#[doc=include_str!("bigdecimal-range.md")]
+impl Encode<'_, Postgres> for BigDecimal {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        PgNumeric::try_from(self)?.encode(buf)?;
+
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        PgNumeric::size_hint(self.digits())
+    }
+}
+
+/// ### Note: `NaN`
+/// `BigDecimal` has a greater range than `NUMERIC` (see the corresponding `Encode` impl for details)
+/// but cannot represent `NaN`, so decoding may return an error.
+impl Decode<'_, Postgres> for BigDecimal {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(),
+            PgValueFormat::Text => Ok(value.as_str()?.parse::<BigDecimal>()?),
+        }
+    }
+}
+
+fn sign_to_pg(sign: Sign) -> PgNumericSign {
+    match sign {
+        Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
+        Sign::Minus => PgNumericSign::Negative,
+    }
+}
+
+#[cfg(test)]
+mod bigdecimal_to_pgnumeric {
+    use super::{BigDecimal, PgNumeric, PgNumericSign};
+    use std::convert::TryFrom;
+
+    #[test]
+    fn zero() {
+        let zero: BigDecimal = "0".parse().unwrap();
+
+        assert_eq!(
+            PgNumeric::try_from(&zero).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 0,
+                digits: vec![]
+            }
+        );
+    }
+
+    #[test]
+    fn one() {
+        let one: BigDecimal = "1".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&one).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 0,
+                digits: vec![1]
+            }
+        );
+    }
+
+    #[test]
+    fn ten() {
+        let ten: BigDecimal = "10".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&ten).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 0,
+                digits: vec![10]
+            }
+        );
+    }
+
+    #[test]
+    fn one_hundred() {
+        let one_hundred: BigDecimal = "100".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&one_hundred).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 0,
+                digits: vec![100]
+            }
+        );
+    }
+
+    #[test]
+    fn ten_thousand() {
+        // BigDecimal doesn't normalize here
+        let ten_thousand: BigDecimal = "10000".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&ten_thousand).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 1,
+                digits: vec![1]
+            }
+        );
+    }
+
+    #[test]
+    fn two_digits() {
+        let two_digits: BigDecimal = "12345".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&two_digits).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 0,
+                weight: 1,
+                digits: vec![1, 2345]
+            }
+        );
+    }
+
+    #[test]
+    fn one_tenth() {
+        let one_tenth: BigDecimal = "0.1".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&one_tenth).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 1,
+                weight: -1,
+                digits: vec![1000]
+            }
+        );
+    }
+
+    #[test]
+    fn one_hundredth() {
+        let one_hundredth: BigDecimal = "0.01".parse().unwrap();
+        assert_eq!(
+            PgNumeric::try_from(&one_hundredth).unwrap(),
+            PgNumeric::Number {
+                sign: PgNumericSign::Positive,
+                scale: 2,
+                weight: -1,
+                digits: vec![100]
+            }
+        );
+    }
+
+    #[test]
+    fn twelve_thousandths() {
+        let twelve_thousandths:
BigDecimal = "0.012".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&twelve_thousandths).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 3, + weight: -1, + digits: vec![120] + } + ); + } + + #[test] + fn decimal_1() { + let decimal: BigDecimal = "1.2345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 4, + weight: 0, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn decimal_2() { + let decimal: BigDecimal = "0.12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![1234, 5000] + } + ); + } + + #[test] + fn decimal_3() { + let decimal: BigDecimal = "0.01234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![0123, 4000] + } + ); + } + + #[test] + fn decimal_4() { + let decimal: BigDecimal = "12345.67890".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: 1, + digits: vec![1, 2345, 6789] + } + ); + } + + #[test] + fn one_digit_decimal() { + let one_digit_decimal: BigDecimal = "0.00001234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_digit_decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 8, + weight: -2, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let four_digit: BigDecimal = "1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_negative_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_four_digit: BigDecimal = "-1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let eight_digit: BigDecimal = "12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } + + #[test] + fn issue_423_negative_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_eight_digit: BigDecimal = "-12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } +} diff --git a/patches/sqlx-postgres/src/types/bit_vec.rs b/patches/sqlx-postgres/src/types/bit_vec.rs new file mode 100644 index 000000000..b519a5f24 --- /dev/null +++ b/patches/sqlx-postgres/src/types/bit_vec.rs @@ -0,0 +1,99 @@ +use crate::arguments::value_size_int4_checked; +use crate::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, + PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres, +}; +use bit_vec::BitVec; +use 
sqlx_core::bytes::Buf;
+use std::{io, mem};
+
+impl Type<Postgres> for BitVec {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::VARBIT
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        *ty == PgTypeInfo::BIT || *ty == PgTypeInfo::VARBIT
+    }
+}
+
+impl PgHasArrayType for BitVec {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::VARBIT_ARRAY
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        *ty == PgTypeInfo::BIT_ARRAY || *ty == PgTypeInfo::VARBIT_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for BitVec {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        let len = value_size_int4_checked(self.len())?;
+
+        buf.extend(len.to_be_bytes());
+        buf.extend(self.to_bytes());
+
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        mem::size_of::<i32>() + self.len()
+    }
+}
+
+impl Decode<'_, Postgres> for BitVec {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Binary => {
+                let mut bytes = value.as_bytes()?;
+                let len = bytes.get_i32();
+
+                let len = usize::try_from(len).map_err(|_| format!("invalid VARBIT len: {len}"))?;
+
+                // The smallest amount of data we can read is one byte
+                let bytes_len = (len + 7) / 8;
+
+                if bytes.remaining() != bytes_len {
+                    Err(io::Error::new(
+                        io::ErrorKind::InvalidData,
+                        "VARBIT length mismatch.",
+                    ))?;
+                }
+
+                let mut bitvec = BitVec::from_bytes(bytes);
+
+                // Chop off zeroes from the back. We get bits in bytes, so if
+                // our bitvec is not in full bytes, extra zeroes are added to
+                // the end.
+                while bitvec.len() > len {
+                    bitvec.pop();
+                }
+
+                Ok(bitvec)
+            }
+            PgValueFormat::Text => {
+                let s = value.as_str()?;
+                let mut bit_vec = BitVec::with_capacity(s.len());
+
+                for c in s.chars() {
+                    match c {
+                        '0' => bit_vec.push(false),
+                        '1' => bit_vec.push(true),
+                        _ => {
+                            Err(io::Error::new(
+                                io::ErrorKind::InvalidData,
+                                "VARBIT data contains characters other than 1 or 0.",
+                            ))?;
+                        }
+                    }
+                }
+
+                Ok(bit_vec)
+            }
+        }
+    }
+}
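+
+// Illustrative sketch (not part of the upstream patch): the binary path above uses
+// `BitVec::from_bytes`, which reads bits most-significant-bit first within each byte.
+#[test]
+fn bit_vec_from_bytes_is_msb_first() {
+    let bits = BitVec::from_bytes(&[0b1010_0000]);
+    assert!(bits[0]);
+    assert!(!bits[1]);
+    assert!(bits[2]);
+}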
{ + "t" => true, + "f" => false, + + s => { + return Err(format!("unexpected value {s:?} for boolean").into()); + } + }, + }) + } +} diff --git a/patches/sqlx-postgres/src/types/bytes.rs b/patches/sqlx-postgres/src/types/bytes.rs new file mode 100644 index 000000000..45968837a --- /dev/null +++ b/patches/sqlx-postgres/src/types/bytes.rs @@ -0,0 +1,112 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +impl PgHasArrayType for u8 { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::BYTEA + } +} + +impl PgHasArrayType for &'_ [u8] { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::BYTEA_ARRAY + } +} + +impl PgHasArrayType for Box<[u8]> { + fn array_type_info() -> PgTypeInfo { + <[&[u8]] as Type>::type_info() + } +} + +impl PgHasArrayType for Vec { + fn array_type_info() -> PgTypeInfo { + <[&[u8]] as Type>::type_info() + } +} + +impl PgHasArrayType for [u8; N] { + fn array_type_info() -> PgTypeInfo { + <[&[u8]] as Type>::type_info() + } +} + +impl Encode<'_, Postgres> for &'_ [u8] { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend_from_slice(self); + + Ok(IsNull::No) + } +} + +impl Encode<'_, Postgres> for Box<[u8]> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&[u8] as Encode>::encode(self.as_ref(), buf) + } +} + +impl Encode<'_, Postgres> for Vec { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&[u8] as Encode>::encode(self, buf) + } +} + +impl Encode<'_, Postgres> for [u8; N] { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&[u8] as Encode>::encode(self.as_slice(), buf) + } +} + +impl<'r> Decode<'r, Postgres> for &'r [u8] { + fn decode(value: PgValueRef<'r>) -> Result { + match value.format() { + PgValueFormat::Binary => value.as_bytes(), + PgValueFormat::Text => { + Err("unsupported decode to `&[u8]` of BYTEA in a simple query; use a prepared query or decode to `Vec`".into()) + } + } + } +} + +fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> { + // BYTEA is formatted as \x followed by hex characters + value + .as_bytes()? 
+ .strip_prefix(b"\\x") + .ok_or("text does not start with \\x") + .map_err(Into::into) +} + +impl Decode<'_, Postgres> for Box<[u8]> { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => Box::from(value.as_bytes()?), + PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?), + }) + } +} + +impl Decode<'_, Postgres> for Vec { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => value.as_bytes()?.to_owned(), + PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?, + }) + } +} + +impl Decode<'_, Postgres> for [u8; N] { + fn decode(value: PgValueRef<'_>) -> Result { + let mut bytes = [0u8; N]; + match value.format() { + PgValueFormat::Binary => { + bytes = value.as_bytes()?.try_into()?; + } + PgValueFormat::Text => hex::decode_to_slice(text_hex_decode_input(value)?, &mut bytes)?, + }; + Ok(bytes) + } +} diff --git a/patches/sqlx-postgres/src/types/chrono/date.rs b/patches/sqlx-postgres/src/types/chrono/date.rs new file mode 100644 index 000000000..0327d5c45 --- /dev/null +++ b/patches/sqlx-postgres/src/types/chrono/date.rs @@ -0,0 +1,64 @@ +use std::mem; + +use chrono::{NaiveDate, TimeDelta}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +impl Type for NaiveDate { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE + } +} + +impl PgHasArrayType for NaiveDate { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::DATE_ARRAY + } +} + +impl Encode<'_, Postgres> for NaiveDate { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // DATE is encoded as the days since epoch + let days: i32 = (*self - postgres_epoch_date()) + .num_days() + .try_into() + .map_err(|_| { + format!("value {self:?} would overflow binary encoding for Postgres DATE") + })?; + + Encode::::encode(days, buf) + } + + fn size_hint(&self) -> usize { + mem::size_of::() + } +} + +impl<'r> Decode<'r, Postgres> for NaiveDate { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => { + // DATE is encoded as the days since epoch + let days: i32 = Decode::::decode(value)?; + + let days = TimeDelta::try_days(days.into()) + .unwrap_or_else(|| { + unreachable!("BUG: days ({days}) as `i32` multiplied into seconds should not overflow `i64`") + }); + + postgres_epoch_date() + days + } + + PgValueFormat::Text => NaiveDate::parse_from_str(value.as_str()?, "%Y-%m-%d")?, + }) + } +} + +#[inline] +fn postgres_epoch_date() -> NaiveDate { + NaiveDate::from_ymd_opt(2000, 1, 1).expect("expected 2000-01-01 to be a valid NaiveDate") +} diff --git a/patches/sqlx-postgres/src/types/chrono/datetime.rs b/patches/sqlx-postgres/src/types/chrono/datetime.rs new file mode 100644 index 000000000..2dceb9e93 --- /dev/null +++ b/patches/sqlx-postgres/src/types/chrono/datetime.rs @@ -0,0 +1,114 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use chrono::{ + DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, Offset, TimeZone, Utc, +}; +use std::mem; + +impl Type for NaiveDateTime { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TIMESTAMP + } +} + +impl Type for DateTime { + fn type_info() -> PgTypeInfo { + 
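+
+// Illustrative sketch (not part of the upstream patch): the binary DATE value is a
+// day count relative to the Postgres epoch, 2000-01-01, so day 1 is 2000-01-02.
+#[test]
+fn postgres_epoch_date_offset_arithmetic() {
+    let day_one = postgres_epoch_date() + TimeDelta::try_days(1).unwrap();
+    assert_eq!(day_one, NaiveDate::from_ymd_opt(2000, 1, 2).unwrap());
+}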
diff --git a/patches/sqlx-postgres/src/types/chrono/datetime.rs b/patches/sqlx-postgres/src/types/chrono/datetime.rs
new file mode 100644
index 000000000..2dceb9e93
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/chrono/datetime.rs
@@ -0,0 +1,114 @@
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+use chrono::{
+    DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, Offset, TimeZone, Utc,
+};
+use std::mem;
+
+impl Type<Postgres> for NaiveDateTime {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::TIMESTAMP
+    }
+}
+
+impl<Tz: TimeZone> Type<Postgres> for DateTime<Tz> {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::TIMESTAMPTZ
+    }
+}
+
+impl PgHasArrayType for NaiveDateTime {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::TIMESTAMP_ARRAY
+    }
+}
+
+impl<Tz: TimeZone> PgHasArrayType for DateTime<Tz> {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::TIMESTAMPTZ_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for NaiveDateTime {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        // TIMESTAMP is encoded as the microseconds since the epoch
+        let micros = (*self - postgres_epoch_datetime())
+            .num_microseconds()
+            .ok_or_else(|| format!("NaiveDateTime out of range for Postgres: {self:?}"))?;
+
+        Encode::<Postgres>::encode(micros, buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        mem::size_of::<i64>()
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for NaiveDateTime {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        Ok(match value.format() {
+            PgValueFormat::Binary => {
+                // TIMESTAMP is encoded as the microseconds since the epoch
+                let us = Decode::<Postgres>::decode(value)?;
+                postgres_epoch_datetime() + Duration::microseconds(us)
+            }
+
+            PgValueFormat::Text => {
+                let s = value.as_str()?;
+                NaiveDateTime::parse_from_str(
+                    s,
+                    if s.contains('+') {
+                        // Contains a time-zone specifier
+                        // This is given for timestamptz for some reason
+                        // Postgres already guarantees this to always be UTC
+                        "%Y-%m-%d %H:%M:%S%.f%#z"
+                    } else {
+                        "%Y-%m-%d %H:%M:%S%.f"
+                    },
+                )?
+            }
+        })
+    }
+}
+
+impl<Tz: TimeZone> Encode<'_, Postgres> for DateTime<Tz> {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        Encode::<Postgres>::encode(self.naive_utc(), buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        mem::size_of::<i64>()
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for DateTime<Local> {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let naive = <NaiveDateTime as Decode<Postgres>>::decode(value)?;
+        Ok(Local.from_utc_datetime(&naive))
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for DateTime<Utc> {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let naive = <NaiveDateTime as Decode<Postgres>>::decode(value)?;
+        Ok(Utc.from_utc_datetime(&naive))
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for DateTime<FixedOffset> {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let naive = <NaiveDateTime as Decode<Postgres>>::decode(value)?;
+        Ok(Utc.fix().from_utc_datetime(&naive))
+    }
+}
+
+#[inline]
+fn postgres_epoch_datetime() -> NaiveDateTime {
+    NaiveDate::from_ymd_opt(2000, 1, 1)
+        .expect("expected 2000-01-01 to be a valid NaiveDate")
+        .and_hms_opt(0, 0, 0)
+        .expect("expected 2000-01-01T00:00:00 to be a valid NaiveDateTime")
+}
diff --git a/patches/sqlx-postgres/src/types/chrono/mod.rs b/patches/sqlx-postgres/src/types/chrono/mod.rs
new file mode 100644
index 000000000..bd27c4d2d
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/chrono/mod.rs
@@ -0,0 +1,3 @@
+mod date;
+mod datetime;
+mod time;
diff --git a/patches/sqlx-postgres/src/types/chrono/time.rs b/patches/sqlx-postgres/src/types/chrono/time.rs
new file mode 100644
index 000000000..ca66f389f
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/chrono/time.rs
@@ -0,0 +1,58 @@
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+use chrono::{Duration, NaiveTime};
+use std::mem;
+
+impl Type<Postgres> for NaiveTime {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::TIME
+    }
+}
+
+impl PgHasArrayType for NaiveTime {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::TIME_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for NaiveTime {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        // TIME is encoded as the microseconds since midnight
+        let micros = (*self - NaiveTime::default())
+            .num_microseconds()
+            .ok_or_else(|| format!("Time out of range for PostgreSQL: {self}"))?;
+
+        Encode::<Postgres>::encode(micros, buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        mem::size_of::<i64>()
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for NaiveTime {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        Ok(match value.format() {
+            PgValueFormat::Binary => {
+                // TIME is encoded as the microseconds since midnight
+                let us: i64 = Decode::<Postgres>::decode(value)?;
+                NaiveTime::default() + Duration::microseconds(us)
+            }
+
+            PgValueFormat::Text => NaiveTime::parse_from_str(value.as_str()?, "%H:%M:%S%.f")?,
+        })
+    }
+}
+
+#[test]
+fn check_naive_time_default_is_midnight() {
+    // Just a canary in case this changes.
+    assert_eq!(
+        NaiveTime::from_hms_opt(0, 0, 0),
+        Some(NaiveTime::default()),
+        "implementation assumes `NaiveTime::default()` equals midnight"
+    );
+}
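+
+// Illustrative sketch (not part of the upstream patch): TIME round-trips through
+// the microseconds-since-midnight representation used by the binary format above.
+#[test]
+fn check_time_microseconds_since_midnight_round_trip() {
+    let time = NaiveTime::from_hms_micro_opt(1, 2, 3, 4).unwrap();
+    let micros = (time - NaiveTime::default()).num_microseconds().unwrap();
+    assert_eq!(NaiveTime::default() + Duration::microseconds(micros), time);
+}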
+ PgTypeInfo::with_name("citext") + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Deref for PgCiText { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_str() + } +} + +impl From for PgCiText { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From for String { + fn from(value: PgCiText) -> Self { + value.0 + } +} + +impl FromStr for PgCiText { + type Err = core::convert::Infallible; + + fn from_str(s: &str) -> Result { + Ok(PgCiText(s.parse()?)) + } +} + +impl Display for PgCiText { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl PgHasArrayType for PgCiText { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_citext") + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + array_compatible::<&str>(ty) + } +} + +impl Encode<'_, Postgres> for PgCiText { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&str as Encode>::encode(&**self, buf) + } +} + +impl Decode<'_, Postgres> for PgCiText { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(PgCiText(value.as_str()?.to_owned())) + } +} diff --git a/patches/sqlx-postgres/src/types/cube.rs b/patches/sqlx-postgres/src/types/cube.rs new file mode 100644 index 000000000..3dd5f59f1 --- /dev/null +++ b/patches/sqlx-postgres/src/types/cube.rs @@ -0,0 +1,537 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::str::FromStr; +use core::mem::size_of; + +const BYTE_WIDTH: usize = 8; + +/// +const MAX_DIMENSIONS: usize = 100; + +const IS_POINT_FLAG: u32 = 1 << 31; + +// FIXME(breaking): these variants are confusingly named and structured +// consider changing them or making this an opaque wrapper around `Vec` +#[derive(Debug, Clone, PartialEq)] +pub enum PgCube { + /// A one-dimensional point. + // FIXME: `Point1D(f64) + Point(f64), + /// An N-dimensional point ("represented internally as a zero-volume cube"). + // FIXME: `PointND(f64)` + ZeroVolume(Vec), + + /// A one-dimensional interval with starting and ending points. + // FIXME: `Interval1D { start: f64, end: f64 }` + OneDimensionInterval(f64, f64), + + // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`? + /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively. 
diff --git a/patches/sqlx-postgres/src/types/cube.rs b/patches/sqlx-postgres/src/types/cube.rs
new file mode 100644
index 000000000..3dd5f59f1
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/cube.rs
@@ -0,0 +1,537 @@
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+use sqlx_core::bytes::Buf;
+use sqlx_core::Error;
+use std::str::FromStr;
+use core::mem::size_of;
+
+const BYTE_WIDTH: usize = 8;
+
+///
+const MAX_DIMENSIONS: usize = 100;
+
+const IS_POINT_FLAG: u32 = 1 << 31;
+
+// FIXME(breaking): these variants are confusingly named and structured
+// consider changing them or making this an opaque wrapper around `Vec`
+#[derive(Debug, Clone, PartialEq)]
+pub enum PgCube {
+    /// A one-dimensional point.
+    // FIXME: `Point1D(f64)`
+    Point(f64),
+    /// An N-dimensional point ("represented internally as a zero-volume cube").
+    // FIXME: `PointND(Vec<f64>)`
+    ZeroVolume(Vec<f64>),
+
+    /// A one-dimensional interval with starting and ending points.
+    // FIXME: `Interval1D { start: f64, end: f64 }`
+    OneDimensionInterval(f64, f64),
+
+    // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`?
+    /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively.
+    // FIXME: `CubeND { lower_left: Vec<f64>, upper_right: Vec<f64> }`
+    MultiDimension(Vec<Vec<f64>>),
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+struct Header {
+    dimensions: usize,
+    is_point: bool,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("error decoding CUBE (is_point: {is_point}, dimensions: {dimensions})")]
+struct DecodeError {
+    is_point: bool,
+    dimensions: usize,
+    message: String,
+}
+
+impl Type<Postgres> for PgCube {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::with_name("cube")
+    }
+}
+
+impl PgHasArrayType for PgCube {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::with_name("_cube")
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for PgCube {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Text => Ok(PgCube::from_str(value.as_str()?)?),
+            PgValueFormat::Binary => Ok(PgCube::from_bytes(value.as_bytes()?)?),
+        }
+    }
+}
+
+impl<'q> Encode<'q, Postgres> for PgCube {
+    fn produces(&self) -> Option<PgTypeInfo> {
+        Some(PgTypeInfo::with_name("cube"))
+    }
+
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        self.serialize(buf)?;
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        self.header().encoded_size()
+    }
+}
+
+impl FromStr for PgCube {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let content = s
+            .trim_start_matches('(')
+            .trim_start_matches('[')
+            .trim_end_matches(')')
+            .trim_end_matches(']')
+            .replace(' ', "");
+
+        if !content.contains('(') && !content.contains(',') {
+            return parse_point(&content);
+        }
+
+        if !content.contains("),(") {
+            return parse_zero_volume(&content);
+        }
+
+        let point_vecs = content.split("),(").collect::<Vec<&str>>();
+        if point_vecs.len() == 2 && !point_vecs.iter().any(|pv| pv.contains(',')) {
+            return parse_one_dimensional_interval(point_vecs);
+        }
+
+        parse_multidimensional_interval(point_vecs)
+    }
+}
+
+impl PgCube {
+    fn header(&self) -> Header {
+        match self {
+            PgCube::Point(..) => Header {
+                is_point: true,
+                dimensions: 1,
+            },
+            PgCube::ZeroVolume(values) => Header {
+                is_point: true,
+                dimensions: values.len(),
+            },
+            PgCube::OneDimensionInterval(..) => Header {
+                is_point: false,
+                dimensions: 1,
+            },
+            PgCube::MultiDimension(multi_values) => Header {
+                is_point: false,
+                dimensions: multi_values.first().map(|arr| arr.len()).unwrap_or(0),
+            },
+        }
+    }
+
+    fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
+        let header = Header::try_read(&mut bytes)?;
+
+        if bytes.len() != header.data_size() {
+            return Err(DecodeError::new(
+                &header,
+                format!(
+                    "expected {} bytes after header, got {}",
+                    header.data_size(),
+                    bytes.len()
+                ),
+            )
+            .into());
+        }
+
+        match (header.is_point, header.dimensions) {
+            (true, 1) => Ok(PgCube::Point(bytes.get_f64())),
+            (true, _) => Ok(PgCube::ZeroVolume(
+                read_vec(&mut bytes).map_err(|e| DecodeError::new(&header, e))?,
+            )),
+            (false, 1) => Ok(PgCube::OneDimensionInterval(
+                bytes.get_f64(),
+                bytes.get_f64(),
+            )),
+            (false, _) => Ok(PgCube::MultiDimension(read_cube(&header, bytes)?)),
+        }
+    }
+
+    fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
+        let header = self.header();
+
+        buff.reserve(header.data_size());
+
+        header.try_write(buff)?;
+
+        match self {
+            PgCube::Point(value) => {
+                buff.extend_from_slice(&value.to_be_bytes());
+            }
+            PgCube::ZeroVolume(values) => {
+                buff.extend(values.iter().flat_map(|v| v.to_be_bytes()));
+            }
+            PgCube::OneDimensionInterval(x, y) => {
+                buff.extend_from_slice(&x.to_be_bytes());
+                buff.extend_from_slice(&y.to_be_bytes());
+            }
+            PgCube::MultiDimension(multi_values) => {
+                if multi_values.len() != 2 {
+                    return Err(format!("invalid CUBE value: {self:?}"));
+                }
+
+                buff.extend(
+                    multi_values
+                        .iter()
+                        .flat_map(|point| point.iter().flat_map(|scalar| scalar.to_be_bytes())),
+                );
+            }
+        };
+        Ok(())
+    }
+
+    #[cfg(test)]
+    fn serialize_to_vec(&self) -> Vec<u8> {
+        let mut buff = PgArgumentBuffer::default();
+        self.serialize(&mut buff).unwrap();
+        buff.to_vec()
+    }
+}
+
+fn read_vec(bytes: &mut &[u8]) -> Result<Vec<f64>, String> {
+    if bytes.len() % BYTE_WIDTH != 0 {
+        return Err(format!(
+            "data length not divisible by {BYTE_WIDTH}: {}",
+            bytes.len()
+        ));
+    }
+
+    let mut out = Vec::with_capacity(bytes.len() / BYTE_WIDTH);
+
+    while bytes.has_remaining() {
+        out.push(bytes.get_f64());
+    }
+
+    Ok(out)
+}
+
+fn read_cube(header: &Header, mut bytes: &[u8]) -> Result<Vec<Vec<f64>>, String> {
+    if bytes.len() != header.data_size() {
+        return Err(format!(
+            "expected {} bytes, got {}",
+            header.data_size(),
+            bytes.len()
+        ));
+    }
+
+    let mut out = Vec::with_capacity(2);
+
+    // Expecting exactly 2 N-dimensional points
+    for _ in 0..2 {
+        let mut point = Vec::new();
+
+        for _ in 0..header.dimensions {
+            point.push(bytes.get_f64());
+        }
+
+        out.push(point);
+    }
+
+    Ok(out)
+}
+
+fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
+    s.parse().map_err(|_| Error::Decode(error_msg.into()))
+}
+
+fn parse_point(str: &str) -> Result<PgCube, Error> {
+    Ok(PgCube::Point(parse_float_from_str(
+        str,
+        "Failed to parse point",
+    )?))
+}
+
+fn parse_zero_volume(content: &str) -> Result<PgCube, Error> {
+    content
+        .split(',')
+        .map(|p| parse_float_from_str(p, "Failed to parse into zero-volume cube"))
+        .collect::<Result<Vec<f64>, _>>()
+        .map(PgCube::ZeroVolume)
+}
+
+fn parse_one_dimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
+    let x = parse_float_from_str(
+        &remove_parentheses(point_vecs.first().ok_or(Error::Decode(
+            format!("Could not decode cube interval x: {:?}", point_vecs).into(),
+        ))?),
+        "Failed to parse X in one-dimensional interval",
+    )?;
+    let y = parse_float_from_str(
+        &remove_parentheses(point_vecs.get(1).ok_or(Error::Decode(
+            format!("Could not decode cube interval y: {:?}", point_vecs).into(),
+        ))?),
+        "Failed to parse Y in one-dimensional interval",
+    )?;
+    Ok(PgCube::OneDimensionInterval(x, y))
+}
+
+fn parse_multidimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
+    point_vecs
+        .iter()
+        .map(|&point_vec| {
+            point_vec
+                .split(',')
+                .map(|point| {
+                    parse_float_from_str(
+                        &remove_parentheses(point),
+                        "Failed to parse into multi-dimension cube",
+                    )
+                })
+                .collect::<Result<Vec<f64>, _>>()
+        })
+        .collect::<Result<Vec<Vec<f64>>, _>>()
+        .map(PgCube::MultiDimension)
+}
+
+fn remove_parentheses(s: &str) -> String {
+    s.trim_matches(|c| c == '(' || c == ')').to_string()
+}
+
+impl Header {
+    const PACKED_WIDTH: usize = size_of::<u32>();
+
+    fn encoded_size(&self) -> usize {
+        Self::PACKED_WIDTH + self.data_size()
+    }
+
+    fn data_size(&self) -> usize {
+        if self.is_point {
+            self.dimensions * BYTE_WIDTH
+        } else {
+            self.dimensions * BYTE_WIDTH * 2
+        }
+    }
+
+    fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
+        if self.dimensions > MAX_DIMENSIONS {
+            return Err(format!(
+                "CUBE dimensionality exceeds allowed maximum ({} > {MAX_DIMENSIONS})",
+                self.dimensions
+            ));
+        }
+
+        // Cannot overflow thanks to the above check.
+        #[allow(clippy::cast_possible_truncation)]
+        let mut packed = self.dimensions as u32;
+
+        // https://github.com/postgres/postgres/blob/e3ec9dc1bf4983fcedb6f43c71ea12ee26aefc7a/contrib/cube/cubedata.h#L18-L24
+        if self.is_point {
+            packed |= IS_POINT_FLAG;
+        }
+
+        buff.extend(packed.to_be_bytes());
+
+        Ok(())
+    }
+
+    fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
+        if buf.len() < Self::PACKED_WIDTH {
+            return Err(format!(
+                "expected CUBE data to contain at least {} bytes, got {}",
+                Self::PACKED_WIDTH,
+                buf.len()
+            ));
+        }
+
+        let packed = buf.get_u32();
+
+        let is_point = packed & IS_POINT_FLAG != 0;
+        let dimensions = packed & !IS_POINT_FLAG;
+
+        // can only overflow on 16-bit platforms
+        let dimensions = usize::try_from(dimensions)
+            .ok()
+            .filter(|&it| it <= MAX_DIMENSIONS)
+            .ok_or_else(|| format!("received CUBE data with higher than expected dimensionality: {dimensions} (is_point: {is_point})"))?;
+
+        Ok(Self {
+            is_point,
+            dimensions,
+        })
+    }
+}
+
+impl DecodeError {
+    fn new(header: &Header, message: String) -> Self {
+        DecodeError {
+            is_point: header.is_point,
+            dimensions: header.dimensions,
+            message,
+        }
+    }
+}
+
+#[cfg(test)]
+mod cube_tests {
+
+    use std::str::FromStr;
+
+    use super::PgCube;
+
+    const POINT_BYTES: &[u8] = &[128, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 0];
+    const ZERO_VOLUME_BYTES: &[u8] = &[
+        128, 0, 0, 2, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
+    ];
+    const ONE_DIMENSIONAL_INTERVAL_BYTES: &[u8] = &[
+        0, 0, 0, 1, 64, 28, 0, 0, 0, 0, 0, 0, 64, 32, 0, 0, 0, 0, 0, 0,
+    ];
+    const MULTI_DIMENSION_2_DIM_BYTES: &[u8] = &[
+        0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
+        64, 16, 0, 0, 0, 0, 0, 0,
+    ];
+    const MULTI_DIMENSION_3_DIM_BYTES: &[u8] = &[
+        0, 0, 0, 3, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 16, 0, 0, 0, 0, 0, 0, 64,
+        20, 0, 0, 0, 0, 0, 0, 64, 24, 0, 0, 0, 0, 0, 0, 64, 28, 0, 0, 0, 0, 0, 0,
+    ];
+
+    #[test]
+    fn can_deserialise_point_type_bytes() {
+        let cube = PgCube::from_bytes(POINT_BYTES).unwrap();
+        assert_eq!(cube, PgCube::Point(2.))
+    }
+
+    #[test]
+    fn can_deserialise_point_type_str() {
+        let cube_1 = PgCube::from_str("(2)").unwrap();
+        assert_eq!(cube_1, PgCube::Point(2.));
+        let cube_2 = PgCube::from_str("2").unwrap();
+        assert_eq!(cube_2, PgCube::Point(2.));
+    }
+
+    #[test]
+    fn can_serialise_point_type() {
+        assert_eq!(PgCube::Point(2.).serialize_to_vec(), POINT_BYTES,)
+    }
+
#[test] + fn can_deserialise_zero_volume_bytes() { + let cube = PgCube::from_bytes(ZERO_VOLUME_BYTES).unwrap(); + assert_eq!(cube, PgCube::ZeroVolume(vec![2., 3.])); + } + + #[test] + fn can_deserialise_zero_volume_string() { + let cube_1 = PgCube::from_str("(2,3,4)").unwrap(); + assert_eq!(cube_1, PgCube::ZeroVolume(vec![2., 3., 4.])); + let cube_2 = PgCube::from_str("2,3,4").unwrap(); + assert_eq!(cube_2, PgCube::ZeroVolume(vec![2., 3., 4.])); + } + + #[test] + fn can_serialise_zero_volume() { + assert_eq!( + PgCube::ZeroVolume(vec![2., 3.]).serialize_to_vec(), + ZERO_VOLUME_BYTES + ); + } + + #[test] + fn can_deserialise_one_dimension_interval_bytes() { + let cube = PgCube::from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap(); + assert_eq!(cube, PgCube::OneDimensionInterval(7., 8.)) + } + + #[test] + fn can_deserialise_one_dimension_interval_string() { + let cube_1 = PgCube::from_str("((7),(8))").unwrap(); + assert_eq!(cube_1, PgCube::OneDimensionInterval(7., 8.)); + let cube_2 = PgCube::from_str("(7),(8)").unwrap(); + assert_eq!(cube_2, PgCube::OneDimensionInterval(7., 8.)); + } + + #[test] + fn can_serialise_one_dimension_interval() { + assert_eq!( + PgCube::OneDimensionInterval(7., 8.).serialize_to_vec(), + ONE_DIMENSIONAL_INTERVAL_BYTES + ) + } + + #[test] + fn can_deserialise_multi_dimension_2_dimension_byte() { + let cube = PgCube::from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap(); + assert_eq!( + cube, + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) + ) + } + + #[test] + fn can_deserialise_multi_dimension_2_dimension_string() { + let cube_1 = PgCube::from_str("((1,2),(3,4))").unwrap(); + assert_eq!( + cube_1, + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) + ); + let cube_2 = PgCube::from_str("((1, 2), (3, 4))").unwrap(); + assert_eq!( + cube_2, + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) + ); + let cube_3 = PgCube::from_str("(1,2),(3,4)").unwrap(); + assert_eq!( + cube_3, + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) + ); + let cube_4 = PgCube::from_str("(1, 2), (3, 4)").unwrap(); + assert_eq!( + cube_4, + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) + ) + } + + #[test] + fn can_serialise_multi_dimension_2_dimension() { + assert_eq!( + PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]).serialize_to_vec(), + MULTI_DIMENSION_2_DIM_BYTES + ) + } + + #[test] + fn can_deserialise_multi_dimension_3_dimension_bytes() { + let cube = PgCube::from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap(); + assert_eq!( + cube, + PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]) + ) + } + + #[test] + fn can_deserialise_multi_dimension_3_dimension_string() { + let cube = PgCube::from_str("((2,3,4),(5,6,7))").unwrap(); + assert_eq!( + cube, + PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]) + ); + let cube_2 = PgCube::from_str("(2,3,4),(5,6,7)").unwrap(); + assert_eq!( + cube_2, + PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]) + ); + } + + #[test] + fn can_serialise_multi_dimension_3_dimension() { + assert_eq!( + PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]).serialize_to_vec(), + MULTI_DIMENSION_3_DIM_BYTES + ) + } +} diff --git a/patches/sqlx-postgres/src/types/float.rs b/patches/sqlx-postgres/src/types/float.rs new file mode 100644 index 000000000..116a28c2d --- /dev/null +++ b/patches/sqlx-postgres/src/types/float.rs @@ -0,0 +1,65 @@ +use byteorder::{BigEndian, ByteOrder}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; 
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+impl Type<Postgres> for f32 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::FLOAT4
+    }
+}
+
+impl PgHasArrayType for f32 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::FLOAT4_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for f32 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for f32 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        Ok(match value.format() {
+            PgValueFormat::Binary => BigEndian::read_f32(value.as_bytes()?),
+            PgValueFormat::Text => value.as_str()?.parse()?,
+        })
+    }
+}
+
+impl Type<Postgres> for f64 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::FLOAT8
+    }
+}
+
+impl PgHasArrayType for f64 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::FLOAT8_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for f64 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for f64 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        Ok(match value.format() {
+            PgValueFormat::Binary => BigEndian::read_f64(value.as_bytes()?),
+            PgValueFormat::Text => value.as_str()?.parse()?,
+        })
+    }
+}
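+
+// Illustrative sketch (not part of the upstream patch): the binary FLOAT4/FLOAT8
+// format is simply big-endian IEEE-754, matching the `to_be_bytes` encoding above.
+#[test]
+fn float_binary_format_is_big_endian() {
+    assert_eq!(BigEndian::read_f32(&1.5f32.to_be_bytes()), 1.5);
+    assert_eq!(BigEndian::read_f64(&(-2.5f64).to_be_bytes()), -2.5);
+}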
diff --git a/patches/sqlx-postgres/src/types/hstore.rs b/patches/sqlx-postgres/src/types/hstore.rs
new file mode 100644
index 000000000..bb61cc547
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/hstore.rs
@@ -0,0 +1,323 @@
+use std::{
+    collections::{btree_map, BTreeMap},
+    mem::size_of,
+    ops::{Deref, DerefMut},
+    str,
+};
+
+use crate::{
+    decode::Decode,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::Type,
+    PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres,
+};
+use serde::{Deserialize, Serialize};
+use sqlx_core::bytes::Buf;
+
+/// Key-value support (`hstore`) for Postgres.
+///
+/// SQLx currently maps `hstore` to a `BTreeMap<String, Option<String>>` but this may be expanded in
+/// future to allow for user defined types.
+///
+/// See [the Postgres manual, Appendix F, Section 18][PG.F.18]
+///
+/// [PG.F.18]: https://www.postgresql.org/docs/current/hstore.html
+///
+/// ### Note: Requires Postgres 8.3+
+/// Introduced as a method for storing unstructured data, the `hstore` extension was first added in
+/// Postgres 8.3.
+///
+///
+/// ### Note: Extension Required
+/// The `hstore` extension is not enabled by default in Postgres. You will need to do so explicitly:
+///
+/// ```ignore
+/// CREATE EXTENSION IF NOT EXISTS hstore;
+/// ```
+///
+/// # Examples
+///
+/// ```
+/// # use sqlx_postgres::types::PgHstore;
+/// // Shows basic usage of the PgHstore type.
+/// //
+/// #[derive(Clone, Debug, Default, Eq, PartialEq)]
+/// struct UserCreate<'a> {
+///     username: &'a str,
+///     password: &'a str,
+///     additional_data: PgHstore
+/// }
+///
+/// let mut new_user = UserCreate {
+///     username: "name.surname@email.com",
+///     password: "@super_secret_1",
+///     ..Default::default()
+/// };
+///
+/// new_user.additional_data.insert("department".to_string(), Some("IT".to_string()));
+/// new_user.additional_data.insert("equipment_issued".to_string(), None);
+/// ```
+/// ```ignore
+/// query_scalar::<_, i64>(
+///     "insert into user(username, password, additional_data) values($1, $2, $3) returning id"
+/// )
+/// .bind(new_user.username)
+/// .bind(new_user.password)
+/// .bind(new_user.additional_data)
+/// .fetch_one(pg_conn)
+/// .await?;
+/// ```
+///
+/// ```
+/// # use sqlx_postgres::types::PgHstore;
+/// // PgHstore implements FromIterator to simplify construction.
+/// //
+/// let additional_data = PgHstore::from_iter([
+///     ("department".to_string(), Some("IT".to_string())),
+///     ("equipment_issued".to_string(), None),
+/// ]);
+///
+/// assert_eq!(additional_data["department"], Some("IT".to_string()));
+/// assert_eq!(additional_data["equipment_issued"], None);
+///
+/// // Also IntoIterator for ease of iteration.
+/// //
+/// for (key, value) in additional_data {
+///     println!("{key}: {value:?}");
+/// }
+/// ```
+///
+#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
+pub struct PgHstore(pub BTreeMap<String, Option<String>>);
+
+impl Deref for PgHstore {
+    type Target = BTreeMap<String, Option<String>>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl DerefMut for PgHstore {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
+impl FromIterator<(String, String)> for PgHstore {
+    fn from_iter<T: IntoIterator<Item = (String, String)>>(iter: T) -> Self {
+        iter.into_iter().map(|(k, v)| (k, Some(v))).collect()
+    }
+}
+
+impl FromIterator<(String, Option<String>)> for PgHstore {
+    fn from_iter<T: IntoIterator<Item = (String, Option<String>)>>(iter: T) -> Self {
+        let mut result = Self::default();
+
+        for (key, value) in iter {
+            result.0.insert(key, value);
+        }
+
+        result
+    }
+}
+
+impl IntoIterator for PgHstore {
+    type Item = (String, Option<String>);
+    type IntoIter = btree_map::IntoIter<String, Option<String>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+impl Type<Postgres> for PgHstore {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::with_name("hstore")
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for PgHstore {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let mut buf = <&[u8] as Decode<Postgres>>::decode(value)?;
+        let len = read_length(&mut buf)?;
+
+        let len =
+            usize::try_from(len).map_err(|_| format!("PgHstore: length out of range: {len}"))?;
+
+        let mut result = Self::default();
+
+        for i in 0..len {
+            let key = read_string(&mut buf)
+                .map_err(|e| format!("PgHstore: error reading {i}th key: {e}"))?
+                .ok_or_else(|| format!("PgHstore: expected {i}th key, got nothing"))?;
+
+            let value = read_string(&mut buf)
+                .map_err(|e| format!("PgHstore: error reading value for key {key:?}: {e}"))?;
+
+            result.insert(key, value);
+        }
+
+        if !buf.is_empty() {
+            tracing::warn!("{} unread bytes at the end of HSTORE value", buf.len());
+        }
+
+        Ok(result)
+    }
+}
+
+impl Encode<'_, Postgres> for PgHstore {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend_from_slice(&i32::to_be_bytes(
+            self.0
+                .len()
+                .try_into()
+                .map_err(|_| format!("PgHstore length out of range: {}", self.0.len()))?,
+        ));
+
+        for (i, (key, val)) in self.0.iter().enumerate() {
+            let key_bytes = key.as_bytes();
+
+            let key_len = i32::try_from(key_bytes.len()).map_err(|_| {
+                // Doesn't make sense to print the key itself: it's more than 2 GiB long!
+                format!(
+                    "PgHstore: length of {i}th key out of range: {} bytes",
+                    key_bytes.len()
+                )
+            })?;
+
+            buf.extend_from_slice(&i32::to_be_bytes(key_len));
+            buf.extend_from_slice(key_bytes);
+
+            match val {
+                Some(val) => {
+                    let val_bytes = val.as_bytes();
+
+                    let val_len = i32::try_from(val_bytes.len()).map_err(|_| {
+                        format!(
+                            "PgHstore: value length for key {key:?} out of range: {} bytes",
+                            val_bytes.len()
+                        )
+                    })?;
+                    buf.extend_from_slice(&i32::to_be_bytes(val_len));
+                    buf.extend_from_slice(val_bytes);
+                }
+                None => {
+                    buf.extend_from_slice(&i32::to_be_bytes(-1));
+                }
+            }
+        }
+
+        Ok(IsNull::No)
+    }
+}
+
+fn read_length(buf: &mut &[u8]) -> Result<i32, String> {
+    if buf.len() < size_of::<i32>() {
+        return Err(format!(
+            "expected {} bytes, got {}",
+            size_of::<i32>(),
+            buf.len()
+        ));
+    }
+
+    Ok(buf.get_i32())
+}
+
+fn read_string(buf: &mut &[u8]) -> Result<Option<String>, String> {
+    let len = read_length(buf)?;
+
+    match len {
+        -1 => Ok(None),
+        len => {
+            let len =
+                usize::try_from(len).map_err(|_| format!("string length out of range: {len}"))?;
+
+            if buf.len() < len {
+                return Err(format!("expected {len} bytes, got {}", buf.len()));
+            }
+
+            let (val, rest) = buf.split_at(len);
+            *buf = rest;
+
+            Ok(Some(
+                str::from_utf8(val).map_err(|e| e.to_string())?.to_string(),
+            ))
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::PgValueFormat;
+
+    const EMPTY: &str = "00000000";
+
+    const NAME_SURNAME_AGE: &str =
+        "0000000300000003616765ffffffff000000046e616d65000000044a6f686e000000077375726e616d6500000003446f65";
+
+    #[test]
+    fn hstore_deserialize_ok() {
+        let empty = hex::decode(EMPTY).unwrap();
+        let name_surname_age = hex::decode(NAME_SURNAME_AGE).unwrap();
+
+        let empty = PgValueRef {
+            value: Some(empty.as_slice()),
+            row: None,
+            type_info: PgTypeInfo::with_name("hstore"),
+            format: PgValueFormat::Binary,
+        };
+
+        let name_surname = PgValueRef {
+            value: Some(name_surname_age.as_slice()),
+            row: None,
+            type_info: PgTypeInfo::with_name("hstore"),
+            format: PgValueFormat::Binary,
+        };
+
+        let res_empty = PgHstore::decode(empty).unwrap();
+        let res_name_surname = PgHstore::decode(name_surname).unwrap();
+
+        assert!(res_empty.is_empty());
+        assert_eq!(res_name_surname["name"], Some("John".to_string()));
+        assert_eq!(res_name_surname["surname"], Some("Doe".to_string()));
+        assert_eq!(res_name_surname["age"], None);
+    }
+
+    #[test]
+    #[should_panic(expected = "PgHstore: length out of range: -5")]
+    fn hstore_deserialize_buffer_length_error() {
+        let buf = PgValueRef {
+            value: Some(&[255, 255, 255, 251]),
+            row: None,
+            type_info: PgTypeInfo::with_name("hstore"),
+            format: PgValueFormat::Binary,
+        };
+
+        PgHstore::decode(buf).unwrap();
+    }
+
+    #[test]
+    fn hstore_serialize_ok() {
+        let mut buff = PgArgumentBuffer::default();
+        let _ = PgHstore::from_iter::<[(String, String); 0]>([])
+            .encode_by_ref(&mut buff)
+            .unwrap();
+
+        assert_eq!(hex::encode(buff.as_slice()), EMPTY);
+
+        buff.clear();
+
+        let _ = PgHstore::from_iter([
+            ("name".to_string(), Some("John".to_string())),
+            ("surname".to_string(), Some("Doe".to_string())),
+            ("age".to_string(), None),
+        ])
+        .encode_by_ref(&mut buff)
+        .unwrap();
+
+        assert_eq!(hex::encode(buff.as_slice()), NAME_SURNAME_AGE);
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/int.rs b/patches/sqlx-postgres/src/types/int.rs
new file mode 100644
index 000000000..b8255f1b0
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/int.rs
@@ -0,0 +1,176 @@
+use byteorder::{BigEndian, ByteOrder};
+use std::num::{NonZeroI16, NonZeroI32, NonZeroI64};
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+fn int_decode(value: PgValueRef<'_>) -> Result<i64, BoxDynError> {
+    Ok(match value.format() {
+        PgValueFormat::Text => value.as_str()?.parse()?,
+        PgValueFormat::Binary => {
+            let buf = value.as_bytes()?;
+
+            // Return error if buf is empty or is more than 8 bytes
+            match buf.len() {
+                0 => {
+                    return Err("Value Buffer found empty while decoding to integer type".into());
+                }
+                buf_len @ 9.. => {
+                    return Err(format!(
+                        "Value Buffer exceeds 8 bytes while decoding to integer type. Buffer size = {} bytes", buf_len
+                    )
+                    .into());
+                }
+                _ => {}
+            }
+
+            BigEndian::read_int(buf, buf.len())
+        }
+    })
+}
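+
+// Illustrative sketch (not part of the upstream patch): `int_decode` rejects
+// binary buffers larger than 8 bytes, per the length check above.
+#[test]
+fn int_decode_rejects_oversized_buffers() {
+    let value = PgValueRef {
+        value: Some(&[0u8; 9]),
+        row: None,
+        type_info: PgTypeInfo::INT8,
+        format: PgValueFormat::Binary,
+    };
+
+    assert!(int_decode(value).is_err());
+}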
+
+impl Type<Postgres> for i8 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::CHAR
+    }
+}
+
+impl PgHasArrayType for i8 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::CHAR_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for i8 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for i8 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        // note: decoding here is for the `"char"` type as Postgres does not have a native 1-byte integer type.
+        // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/char.c#L58-L60
+        match value.format() {
+            PgValueFormat::Binary => int_decode(value)?.try_into().map_err(Into::into),
+            PgValueFormat::Text => {
+                let text = value.as_str()?;
+
+                // A value of 0 is represented with the empty string.
+                if text.is_empty() {
+                    return Ok(0);
+                }
+
+                if text.starts_with('\\') {
+                    // For values between 0x80 and 0xFF, it's encoded in octal.
+                    return Ok(i8::from_str_radix(text.trim_start_matches('\\'), 8)?);
+                }
+
+                // Wrapping is the whole idea.
+                #[allow(clippy::cast_possible_wrap)]
+                Ok(text.as_bytes()[0] as i8)
+            }
+        }
+    }
+}
+
+impl Type<Postgres> for i16 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::INT2
+    }
+}
+
+impl PgHasArrayType for i16 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT2_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for i16 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for i16 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        int_decode(value)?.try_into().map_err(Into::into)
+    }
+}
+
+impl Type<Postgres> for i32 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::INT4
+    }
+}
+
+impl PgHasArrayType for i32 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT4_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for i32 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for i32 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        int_decode(value)?.try_into().map_err(Into::into)
+    }
+}
+
+impl Type<Postgres> for i64 {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::INT8
+    }
+}
+
+impl PgHasArrayType for i64 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT8_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for i64 {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl Decode<'_, Postgres> for i64 {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        int_decode(value)
+    }
+}
+
+impl PgHasArrayType for NonZeroI16 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT2_ARRAY
+    }
+}
+
+impl PgHasArrayType for NonZeroI32 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT4_ARRAY
+    }
+}
+
+impl PgHasArrayType for NonZeroI64 {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INT8_ARRAY
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/interval.rs b/patches/sqlx-postgres/src/types/interval.rs
new file mode 100644
index 000000000..52ab54991
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/interval.rs
@@ -0,0 +1,399 @@
+use std::mem;
+
+use byteorder::{NetworkEndian, ReadBytesExt};
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+// `PgInterval` is available for direct access to the INTERVAL type
+
+#[derive(Debug, Eq, PartialEq, Clone, Hash, Default)]
+pub struct PgInterval {
+    pub months: i32,
+    pub days: i32,
+    pub microseconds: i64,
+}
+
+impl Type<Postgres> for PgInterval {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::INTERVAL
+    }
+}
+
+impl PgHasArrayType for PgInterval {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INTERVAL_ARRAY
+    }
+}
+
+impl<'de> Decode<'de, Postgres> for PgInterval {
+    fn decode(value: PgValueRef<'de>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Binary => {
+                let mut buf = value.as_bytes()?;
+                let microseconds = buf.read_i64::<NetworkEndian>()?;
+                let days = buf.read_i32::<NetworkEndian>()?;
+                let months = buf.read_i32::<NetworkEndian>()?;
+
+                Ok(PgInterval {
+                    months,
+                    days,
+                    microseconds,
+                })
+            }
+
+            // TODO: Implement parsing of text mode
+            PgValueFormat::Text => {
+                Err("not implemented: decode `INTERVAL` in text mode (unprepared queries)".into())
+            }
+        }
+    }
+}
+
+impl Encode<'_, Postgres> for PgInterval {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(&self.microseconds.to_be_bytes());
+        buf.extend(&self.days.to_be_bytes());
buf.extend(&self.months.to_be_bytes()); + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + 2 * mem::size_of::() + } +} + +// We then implement Encode + Type for std Duration, chrono Duration, and time Duration +// This is to enable ease-of-use for encoding when its simple + +impl Type for std::time::Duration { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL + } +} + +impl PgHasArrayType for std::time::Duration { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL_ARRAY + } +} + +impl Encode<'_, Postgres> for std::time::Duration { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + PgInterval::try_from(*self)?.encode_by_ref(buf) + } + + fn size_hint(&self) -> usize { + 2 * mem::size_of::() + } +} + +impl TryFrom for PgInterval { + type Error = BoxDynError; + + /// Convert a `std::time::Duration` to a `PgInterval` + /// + /// This returns an error if there is a loss of precision using nanoseconds or if there is a + /// microsecond overflow. + fn try_from(value: std::time::Duration) -> Result { + if value.as_nanos() % 1000 != 0 { + return Err("PostgreSQL `INTERVAL` does not support nanoseconds precision".into()); + } + + Ok(Self { + months: 0, + days: 0, + microseconds: value.as_micros().try_into()?, + }) + } +} + +#[cfg(feature = "chrono")] +impl Type for chrono::Duration { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL + } +} + +#[cfg(feature = "chrono")] +impl PgHasArrayType for chrono::Duration { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl Encode<'_, Postgres> for chrono::Duration { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let pg_interval = PgInterval::try_from(*self)?; + pg_interval.encode_by_ref(buf) + } + + fn size_hint(&self) -> usize { + 2 * mem::size_of::() + } +} + +#[cfg(feature = "chrono")] +impl TryFrom for PgInterval { + type Error = BoxDynError; + + /// Convert a `chrono::Duration` to a `PgInterval`. + /// + /// This returns an error if there is a loss of precision using nanoseconds or if there is a + /// nanosecond overflow. + fn try_from(value: chrono::Duration) -> Result { + value + .num_nanoseconds() + .map_or::, _>( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |nanoseconds| { + if nanoseconds % 1000 != 0 { + return Err( + "PostgreSQL `INTERVAL` does not support nanoseconds precision".into(), + ); + } + Ok(()) + }, + )?; + + value.num_microseconds().map_or( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |microseconds| { + Ok(Self { + months: 0, + days: 0, + microseconds, + }) + }, + ) + } +} + +#[cfg(feature = "time")] +impl Type for time::Duration { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL + } +} + +#[cfg(feature = "time")] +impl PgHasArrayType for time::Duration { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INTERVAL_ARRAY + } +} + +#[cfg(feature = "time")] +impl Encode<'_, Postgres> for time::Duration { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let pg_interval = PgInterval::try_from(*self)?; + pg_interval.encode_by_ref(buf) + } + + fn size_hint(&self) -> usize { + 2 * mem::size_of::() + } +} + +#[cfg(feature = "time")] +impl TryFrom for PgInterval { + type Error = BoxDynError; + + /// Convert a `time::Duration` to a `PgInterval`. + /// + /// This returns an error if there is a loss of precision using nanoseconds or if there is a + /// microsecond overflow. 
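+    /// A sketch of the happy path (illustrative):
+    ///
+    /// ```ignore
+    /// let interval = PgInterval::try_from(time::Duration::seconds(1))?;
+    /// assert_eq!(interval.microseconds, 1_000_000);
+    /// ```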
+    fn try_from(value: time::Duration) -> Result<Self, Self::Error> {
+        if value.whole_nanoseconds() % 1000 != 0 {
+            return Err("PostgreSQL `INTERVAL` does not support nanoseconds precision".into());
+        }
+
+        Ok(Self {
+            months: 0,
+            days: 0,
+            microseconds: value.whole_microseconds().try_into()?,
+        })
+    }
+}
+
+#[test]
+fn test_encode_interval() {
+    let mut buf = PgArgumentBuffer::default();
+
+    let interval = PgInterval {
+        months: 0,
+        days: 0,
+        microseconds: 0,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+    buf.clear();
+
+    let interval = PgInterval {
+        months: 0,
+        days: 0,
+        microseconds: 1_000,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 0, 0, 0, 0, 0, 0]);
+    buf.clear();
+
+    let interval = PgInterval {
+        months: 0,
+        days: 0,
+        microseconds: 1_000_000,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(&**buf, [0, 0, 0, 0, 0, 15, 66, 64, 0, 0, 0, 0, 0, 0, 0, 0]);
+    buf.clear();
+
+    let interval = PgInterval {
+        months: 0,
+        days: 0,
+        microseconds: 3_600_000_000,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(
+        &**buf,
+        [0, 0, 0, 0, 214, 147, 164, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+    );
+    buf.clear();
+
+    let interval = PgInterval {
+        months: 0,
+        days: 1,
+        microseconds: 0,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]);
+    buf.clear();
+
+    let interval = PgInterval {
+        months: 1,
+        days: 0,
+        microseconds: 0,
+    };
+    assert!(matches!(
+        Encode::<Postgres>::encode(&interval, &mut buf),
+        Ok(IsNull::No)
+    ));
+    assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
+    buf.clear();
+
+    assert_eq!(
+        PgInterval::default(),
+        PgInterval {
+            months: 0,
+            days: 0,
+            microseconds: 0,
+        }
+    );
+}
+
+#[test]
+fn test_pginterval_std() {
+    // Case for positive duration
+    let interval = PgInterval {
+        days: 0,
+        months: 0,
+        microseconds: 27_000,
+    };
+    assert_eq!(
+        &PgInterval::try_from(std::time::Duration::from_micros(27_000)).unwrap(),
+        &interval
+    );
+
+    // Case when precision loss occurs
+    assert!(PgInterval::try_from(std::time::Duration::from_nanos(27_000_001)).is_err());
+
+    // Case when microsecond overflow occurs
+    assert!(PgInterval::try_from(std::time::Duration::from_secs(20_000_000_000_000)).is_err());
+}
+
+#[test]
+#[cfg(feature = "chrono")]
+fn test_pginterval_chrono() {
+    // Case for positive duration
+    let interval = PgInterval {
+        days: 0,
+        months: 0,
+        microseconds: 27_000,
+    };
+    assert_eq!(
+        &PgInterval::try_from(chrono::Duration::microseconds(27_000)).unwrap(),
+        &interval
+    );
+
+    // Case for negative duration
+    let interval = PgInterval {
+        days: 0,
+        months: 0,
+        microseconds: -27_000,
+    };
+    assert_eq!(
+        &PgInterval::try_from(chrono::Duration::microseconds(-27_000)).unwrap(),
+        &interval
+    );
+
+    // Case when precision loss occurs
+    assert!(PgInterval::try_from(chrono::Duration::nanoseconds(27_000_001)).is_err());
+    assert!(PgInterval::try_from(chrono::Duration::nanoseconds(-27_000_001)).is_err());
+
+    // Case when nanosecond overflow occurs
+    assert!(PgInterval::try_from(chrono::Duration::seconds(10_000_000_000)).is_err());
+    assert!(PgInterval::try_from(chrono::Duration::seconds(-10_000_000_000)).is_err());
+}
+
+#[test]
+#[cfg(feature = "time")]
+fn 
test_pginterval_time() {
+    // Case for positive duration
+    let interval = PgInterval {
+        days: 0,
+        months: 0,
+        microseconds: 27_000,
+    };
+    assert_eq!(
+        &PgInterval::try_from(time::Duration::microseconds(27_000)).unwrap(),
+        &interval
+    );
+
+    // Case for negative duration
+    let interval = PgInterval {
+        days: 0,
+        months: 0,
+        microseconds: -27_000,
+    };
+    assert_eq!(
+        &PgInterval::try_from(time::Duration::microseconds(-27_000)).unwrap(),
+        &interval
+    );
+
+    // Case when precision loss occurs
+    assert!(PgInterval::try_from(time::Duration::nanoseconds(27_000_001)).is_err());
+    assert!(PgInterval::try_from(time::Duration::nanoseconds(-27_000_001)).is_err());
+
+    // Case when microsecond overflow occurs
+    assert!(PgInterval::try_from(time::Duration::seconds(10_000_000_000_000)).is_err());
+    assert!(PgInterval::try_from(time::Duration::seconds(-10_000_000_000_000)).is_err());
+}
diff --git a/patches/sqlx-postgres/src/types/ipaddr.rs b/patches/sqlx-postgres/src/types/ipaddr.rs
new file mode 100644
index 000000000..ee587eda1
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/ipaddr.rs
@@ -0,0 +1,62 @@
+use std::net::IpAddr;
+
+use ipnetwork::IpNetwork;
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
+
+impl Type<Postgres> for IpAddr
+where
+    IpNetwork: Type<Postgres>,
+{
+    fn type_info() -> PgTypeInfo {
+        IpNetwork::type_info()
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        IpNetwork::compatible(ty)
+    }
+}
+
+impl PgHasArrayType for IpAddr {
+    fn array_type_info() -> PgTypeInfo {
+        <IpNetwork as PgHasArrayType>::array_type_info()
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        <IpNetwork as PgHasArrayType>::array_compatible(ty)
+    }
+}
+
+impl<'db> Encode<'db, Postgres> for IpAddr
+where
+    IpNetwork: Encode<'db, Postgres>,
+{
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        IpNetwork::from(*self).encode_by_ref(buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        IpNetwork::from(*self).size_hint()
+    }
+}
+
+impl<'db> Decode<'db, Postgres> for IpAddr
+where
+    IpNetwork: Decode<'db, Postgres>,
+{
+    fn decode(value: PgValueRef<'db>) -> Result<Self, BoxDynError> {
+        let ipnetwork = IpNetwork::decode(value)?;
+
+        if ipnetwork.is_ipv4() && ipnetwork.prefix() != 32
+            || ipnetwork.is_ipv6() && ipnetwork.prefix() != 128
+        {
+            Err("lossy decode from inet/cidr")?
+        }
+
+        Ok(ipnetwork.ip())
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/ipnetwork.rs b/patches/sqlx-postgres/src/types/ipnetwork.rs
new file mode 100644
index 000000000..4f619ba99
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/ipnetwork.rs
@@ -0,0 +1,122 @@
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39
+
+// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc`
+// just for one constant.
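+// For reference (matching the Postgres source linked above): AF_INET is 2 on
+// effectively every platform, so PGSQL_AF_INET == 2 and PGSQL_AF_INET6 == 3.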
+const PGSQL_AF_INET: u8 = 2; // AF_INET +const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1; + +impl Type for IpNetwork { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INET + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET + } +} + +impl PgHasArrayType for IpNetwork { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INET_ARRAY + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY + } +} + +impl Encode<'_, Postgres> for IpNetwork { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293 + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271 + + match self { + IpNetwork::V4(net) => { + buf.push(PGSQL_AF_INET); // ip_family + buf.push(net.prefix()); // ip_bits + buf.push(0); // is_cidr + buf.push(4); // nb (number of bytes) + buf.extend_from_slice(&net.ip().octets()) // address + } + + IpNetwork::V6(net) => { + buf.push(PGSQL_AF_INET6); // ip_family + buf.push(net.prefix()); // ip_bits + buf.push(0); // is_cidr + buf.push(16); // nb (number of bytes) + buf.extend_from_slice(&net.ip().octets()); // address + } + } + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + match self { + IpNetwork::V4(_) => 8, + IpNetwork::V6(_) => 20, + } + } +} + +impl Decode<'_, Postgres> for IpNetwork { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + return Ok(value.as_str()?.parse()?); + } + }; + + if bytes.len() >= 8 { + let family = bytes[0]; + let prefix = bytes[1]; + let _is_cidr = bytes[2] != 0; + let len = bytes[3]; + + match family { + PGSQL_AF_INET => { + if bytes.len() == 8 && len == 4 { + let inet = Ipv4Network::new( + Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), + prefix, + )?; + + return Ok(IpNetwork::V4(inet)); + } + } + + PGSQL_AF_INET6 => { + if bytes.len() == 20 && len == 16 { + let inet = Ipv6Network::new( + Ipv6Addr::from([ + bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], + bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + bytes[16], bytes[17], bytes[18], bytes[19], + ]), + prefix, + )?; + + return Ok(IpNetwork::V6(inet)); + } + } + + _ => { + return Err(format!("unknown ip family {family}").into()); + } + } + } + + Err("invalid data received when expecting an INET".into()) + } +} diff --git a/patches/sqlx-postgres/src/types/json.rs b/patches/sqlx-postgres/src/types/json.rs new file mode 100644 index 000000000..567e48015 --- /dev/null +++ b/patches/sqlx-postgres/src/types/json.rs @@ -0,0 +1,99 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::array_compatible; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue as JsonRawValue; +use serde_json::Value as JsonValue; +pub(crate) use sqlx_core::types::{Json, Type}; + +// + +// In general, most applications should prefer to store JSON data as jsonb, +// unless there are quite specialized needs, such as legacy assumptions +// about ordering of object keys. 
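+//
+// Illustrative usage (assumed; not part of the original file): a value wrapped in
+// `Json<T>` can be bound directly as a query parameter, e.g.
+// `sqlx::query("INSERT INTO t (data) VALUES ($1)").bind(Json(my_value))`,
+// and is sent as `JSONB` by default per the `Type` impl below.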
+
+impl<T> Type<Postgres> for Json<T> {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::JSONB
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSONB
+    }
+}
+
+impl<T> PgHasArrayType for Json<T> {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::JSONB_ARRAY
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        array_compatible::<Json<T>>(ty)
+    }
+}
+
+impl PgHasArrayType for JsonValue {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::JSONB_ARRAY
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        array_compatible::<JsonValue>(ty)
+    }
+}
+
+impl PgHasArrayType for JsonRawValue {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::JSONB_ARRAY
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        array_compatible::<JsonRawValue>(ty)
+    }
+}
+
+impl<'q, T> Encode<'q, Postgres> for Json<T>
+where
+    T: Serialize,
+{
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        // we have a tiny amount of dynamic behavior depending if we are resolved to be JSON
+        // instead of JSONB
+        buf.patch(|buf, ty: &PgTypeInfo| {
+            if *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSON_ARRAY {
+                buf[0] = b' ';
+            }
+        });
+
+        // JSONB version (as of 2020-03-20)
+        buf.push(1);
+
+        // the JSON data written to the buffer is the same regardless of parameter type
+        serde_json::to_writer(&mut **buf, &self.0)?;
+
+        Ok(IsNull::No)
+    }
+}
+
+impl<'r, T: 'r> Decode<'r, Postgres> for Json<T>
+where
+    T: Deserialize<'r>,
+{
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        let mut buf = value.as_bytes()?;
+
+        if value.format() == PgValueFormat::Binary && value.type_info == PgTypeInfo::JSONB {
+            assert_eq!(
+                buf[0], 1,
+                "unsupported JSONB format version {}; please open an issue",
+                buf[0]
+            );
+
+            buf = &buf[1..];
+        }
+
+        serde_json::from_slice(buf).map(Json).map_err(Into::into)
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/lquery.rs b/patches/sqlx-postgres/src/types/lquery.rs
new file mode 100644
index 000000000..faed95751
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/lquery.rs
@@ -0,0 +1,335 @@
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+use bitflags::bitflags;
+use std::fmt::{self, Display, Formatter};
+use std::io::Write;
+use std::ops::Deref;
+use std::str::FromStr;
+
+use crate::types::ltree::{PgLTreeLabel, PgLTreeParseError};
+
+/// Represents lquery specific errors
+#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
+pub enum PgLQueryParseError {
+    #[error("lquery cannot be empty")]
+    EmptyString,
+    #[error("unexpected character in lquery")]
+    UnexpectedCharacter,
+    #[error("error parsing integer: {0}")]
+    ParseIntError(#[from] std::num::ParseIntError),
+    #[error("error parsing ltree label: {0}")]
+    LTreeParseError(#[from] PgLTreeParseError),
+    /// LQuery version not supported
+    #[error("lquery version not supported")]
+    InvalidLqueryVersion,
+}
+
+/// Container for a Label Tree Query (`lquery`) in Postgres.
+///
+/// See <https://www.postgresql.org/docs/current/ltree.html>
+///
+/// ### Note: Requires Postgres 13+
+///
+/// This integration requires that the `lquery` type support the binary format in the Postgres
+/// wire protocol, which only became available in Postgres 13.
+/// ([Postgres 13.0 Release Notes, Additional Modules](https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14))
+///
+/// Ideally, SQLx's Postgres driver should support falling back to text format for types
+/// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs
+/// to be done.
+///
+/// ### Note: Extension Required
+/// The `ltree` extension is not enabled by default in Postgres. You will need to do so explicitly:
+///
+/// ```ignore
+/// CREATE EXTENSION IF NOT EXISTS "ltree";
+/// ```
+#[derive(Clone, Debug, Default, PartialEq)]
+pub struct PgLQuery {
+    levels: Vec<PgLQueryLevel>,
+}
+
+// TODO: maybe a QueryBuilder pattern would be nice here
+impl PgLQuery {
+    /// creates default/empty lquery
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn from(levels: Vec<PgLQueryLevel>) -> Self {
+        Self { levels }
+    }
+
+    /// push a query level
+    pub fn push(&mut self, level: PgLQueryLevel) {
+        self.levels.push(level);
+    }
+
+    /// pop a query level
+    pub fn pop(&mut self) -> Option<PgLQueryLevel> {
+        self.levels.pop()
+    }
+
+    /// creates lquery from an iterator, checking that each level is valid
+    // TODO: this should just be removed but I didn't want to bury it in a massive diff
+    #[deprecated = "renamed to `try_from_iter()`"]
+    #[allow(clippy::should_implement_trait)]
+    pub fn from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
+    where
+        S: Into<String>,
+        I: IntoIterator<Item = S>,
+    {
+        let mut lquery = Self::default();
+        for level in levels {
+            lquery.push(PgLQueryLevel::from_str(&level.into())?);
+        }
+        Ok(lquery)
+    }
+
+    /// Create an `LQUERY` from an iterator of label strings.
+    ///
+    /// Returns an error if any label fails to parse according to [`PgLQueryLevel::from_str()`].
+    pub fn try_from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
+    where
+        S: AsRef<str>,
+        I: IntoIterator<Item = S>,
+    {
+        levels
+            .into_iter()
+            .map(|level| level.as_ref().parse::<PgLQueryLevel>())
+            .collect()
+    }
+}
+
+impl FromIterator<PgLQueryLevel> for PgLQuery {
+    fn from_iter<T: IntoIterator<Item = PgLQueryLevel>>(iter: T) -> Self {
+        Self::from(iter.into_iter().collect())
+    }
+}
+
+impl IntoIterator for PgLQuery {
+    type Item = PgLQueryLevel;
+    type IntoIter = std::vec::IntoIter<PgLQueryLevel>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.levels.into_iter()
+    }
+}
+
+impl FromStr for PgLQuery {
+    type Err = PgLQueryParseError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Ok(Self {
+            levels: s
+                .split('.')
+                .map(PgLQueryLevel::from_str)
+                .collect::<Result<_, Self::Err>>()?,
+        })
+    }
+}
+
+impl Display for PgLQuery {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        let mut iter = self.levels.iter();
+        if let Some(label) = iter.next() {
+            write!(f, "{label}")?;
+            for label in iter {
+                write!(f, ".{label}")?;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl Deref for PgLQuery {
+    type Target = [PgLQueryLevel];
+
+    fn deref(&self) -> &Self::Target {
+        &self.levels
+    }
+}
+
+impl Type<Postgres> for PgLQuery {
+    fn type_info() -> PgTypeInfo {
+        // Since `lquery` is provided by the `ltree` extension, it does not have a stable OID.
+        PgTypeInfo::with_name("lquery")
+    }
+}
+
+impl Encode<'_, Postgres> for PgLQuery {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        buf.extend(1i8.to_le_bytes());
+        write!(buf, "{self}")?;
+
+        Ok(IsNull::No)
+    }
+}
+
+impl<'r> Decode<'r, Postgres> for PgLQuery {
+    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Binary => {
+                let bytes = value.as_bytes()?;
+                let version = i8::from_le_bytes([bytes[0]; 1]);
+                if version != 1 {
+                    return Err(Box::new(PgLQueryParseError::InvalidLqueryVersion));
+                }
+                Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?)
+ } + PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?), + } + } +} + +bitflags! { + /// Modifiers that can be set to non-star labels + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct PgLQueryVariantFlag: u16 { + /// * - Match any label with this prefix, for example foo* matches foobar + const ANY_END = 0x01; + /// @ - Match case-insensitively, for example a@ matches A + const IN_CASE = 0x02; + /// % - Match initial underscore-separated words + const SUBLEXEME = 0x04; + } +} + +impl Display for PgLQueryVariantFlag { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.contains(PgLQueryVariantFlag::ANY_END) { + write!(f, "*")?; + } + if self.contains(PgLQueryVariantFlag::IN_CASE) { + write!(f, "@")?; + } + if self.contains(PgLQueryVariantFlag::SUBLEXEME) { + write!(f, "%")?; + } + + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PgLQueryVariant { + label: PgLTreeLabel, + modifiers: PgLQueryVariantFlag, +} + +impl Display for PgLQueryVariant { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}{}", self.label, self.modifiers) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum PgLQueryLevel { + /// match any label (*) with optional at least / at most numbers + Star(Option, Option), + /// match any of specified labels with optional flags + NonStar(Vec), + /// match none of specified labels with optional flags + NotNonStar(Vec), +} + +impl FromStr for PgLQueryLevel { + type Err = PgLQueryParseError; + + fn from_str(s: &str) -> Result { + let bytes = s.as_bytes(); + if bytes.is_empty() { + Err(PgLQueryParseError::EmptyString) + } else { + match bytes[0] { + b'*' => { + if bytes.len() > 1 { + let parts = s[2..s.len() - 1].split(',').collect::>(); + match parts.len() { + 1 => { + let number = parts[0].parse()?; + Ok(PgLQueryLevel::Star(Some(number), Some(number))) + } + 2 => Ok(PgLQueryLevel::Star( + Some(parts[0].parse()?), + Some(parts[1].parse()?), + )), + _ => Err(PgLQueryParseError::UnexpectedCharacter), + } + } else { + Ok(PgLQueryLevel::Star(None, None)) + } + } + b'!' => Ok(PgLQueryLevel::NotNonStar( + s[1..] + .split('|') + .map(PgLQueryVariant::from_str) + .collect::, PgLQueryParseError>>()?, + )), + _ => Ok(PgLQueryLevel::NonStar( + s.split('|') + .map(PgLQueryVariant::from_str) + .collect::, PgLQueryParseError>>()?, + )), + } + } + } +} + +impl FromStr for PgLQueryVariant { + type Err = PgLQueryParseError; + + fn from_str(s: &str) -> Result { + let mut label_length = s.len(); + let mut modifiers = PgLQueryVariantFlag::empty(); + + for b in s.bytes().rev() { + match b { + b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE), + b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END), + b'%' => modifiers.insert(PgLQueryVariantFlag::SUBLEXEME), + _ => break, + } + label_length -= 1; + } + + Ok(PgLQueryVariant { + label: PgLTreeLabel::new(&s[0..label_length])?, + modifiers, + }) + } +} + +fn write_variants(f: &mut Formatter<'_>, variants: &[PgLQueryVariant], not: bool) -> fmt::Result { + let mut iter = variants.iter(); + if let Some(variant) = iter.next() { + write!(f, "{}{}", if not { "!" 
} else { "" }, variant)?; + for variant in iter { + write!(f, ".{variant}")?; + } + } + Ok(()) +} + +impl Display for PgLQueryLevel { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + PgLQueryLevel::Star(Some(at_least), Some(at_most)) => { + if at_least == at_most { + write!(f, "*{{{at_least}}}") + } else { + write!(f, "*{{{at_least},{at_most}}}") + } + } + PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"), + PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"), + PgLQueryLevel::Star(_, _) => write!(f, "*"), + PgLQueryLevel::NonStar(variants) => write_variants(f, variants, false), + PgLQueryLevel::NotNonStar(variants) => write_variants(f, variants, true), + } + } +} diff --git a/patches/sqlx-postgres/src/types/ltree.rs b/patches/sqlx-postgres/src/types/ltree.rs new file mode 100644 index 000000000..531f50656 --- /dev/null +++ b/patches/sqlx-postgres/src/types/ltree.rs @@ -0,0 +1,228 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use std::fmt::{self, Display, Formatter}; +use std::io::Write; +use std::ops::Deref; +use std::str::FromStr; + +/// Represents ltree specific errors +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum PgLTreeParseError { + /// LTree labels can only contain [A-Za-z0-9_] + #[error("ltree label contains invalid characters")] + InvalidLtreeLabel, + + /// LTree version not supported + #[error("ltree version not supported")] + InvalidLtreeVersion, +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgLTreeLabel(String); + +impl PgLTreeLabel { + pub fn new(label: S) -> Result + where + S: Into, + { + let label = label.into(); + if label.len() <= 256 + && label + .bytes() + .all(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || c == b'_') + { + Ok(Self(label)) + } else { + Err(PgLTreeParseError::InvalidLtreeLabel) + } + } +} + +impl Deref for PgLTreeLabel { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_str() + } +} + +impl FromStr for PgLTreeLabel { + type Err = PgLTreeParseError; + + fn from_str(s: &str) -> Result { + PgLTreeLabel::new(s) + } +} + +impl Display for PgLTreeLabel { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Container for a Label Tree (`ltree`) in Postgres. +/// +/// See +/// +/// ### Note: Requires Postgres 13+ +/// +/// This integration requires that the `ltree` type support the binary format in the Postgres +/// wire protocol, which only became available in Postgres 13. +/// ([Postgres 13.0 Release Notes, Additional Modules](https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14)) +/// +/// Ideally, SQLx's Postgres driver should support falling back to text format for types +/// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs +/// to be done. +/// +/// ### Note: Extension Required +/// The `ltree` extension is not enabled by default in Postgres. 
You will need to do so explicitly: +/// +/// ```ignore +/// CREATE EXTENSION IF NOT EXISTS "ltree"; +/// ``` +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgLTree { + labels: Vec, +} + +impl PgLTree { + /// creates default/empty ltree + pub fn new() -> Self { + Self::default() + } + + /// creates ltree from a [`Vec`] + pub fn from(labels: Vec) -> Self { + Self { labels } + } + + /// creates ltree from an iterator with checking labels + // TODO: this should just be removed but I didn't want to bury it in a massive diff + #[deprecated = "renamed to `try_from_iter()`"] + #[allow(clippy::should_implement_trait)] + pub fn from_iter(labels: I) -> Result + where + String: From, + I: IntoIterator, + { + let mut ltree = Self::default(); + for label in labels { + ltree.push(PgLTreeLabel::new(label)?); + } + Ok(ltree) + } + + /// Create an `LTREE` from an iterator of label strings. + /// + /// Returns an error if any label fails to parse according to [`PgLTreeLabel::new()`]. + pub fn try_from_iter(labels: I) -> Result + where + S: Into, + I: IntoIterator, + { + labels.into_iter().map(PgLTreeLabel::new).collect() + } + + /// push a label to ltree + pub fn push(&mut self, label: PgLTreeLabel) { + self.labels.push(label); + } + + /// pop a label from ltree + pub fn pop(&mut self) -> Option { + self.labels.pop() + } +} + +impl FromIterator for PgLTree { + fn from_iter>(iter: T) -> Self { + Self { + labels: iter.into_iter().collect(), + } + } +} + +impl IntoIterator for PgLTree { + type Item = PgLTreeLabel; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.labels.into_iter() + } +} + +impl FromStr for PgLTree { + type Err = PgLTreeParseError; + + fn from_str(s: &str) -> Result { + Ok(Self { + labels: s + .split('.') + .map(PgLTreeLabel::new) + .collect::, Self::Err>>()?, + }) + } +} + +impl Display for PgLTree { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut iter = self.labels.iter(); + if let Some(label) = iter.next() { + write!(f, "{label}")?; + for label in iter { + write!(f, ".{label}")?; + } + } + Ok(()) + } +} + +impl Deref for PgLTree { + type Target = [PgLTreeLabel]; + + fn deref(&self) -> &Self::Target { + &self.labels + } +} + +impl Type for PgLTree { + fn type_info() -> PgTypeInfo { + // Since `ltree` is enabled by an extension, it does not have a stable OID. + PgTypeInfo::with_name("ltree") + } +} + +impl PgHasArrayType for PgLTree { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_ltree") + } +} + +impl Encode<'_, Postgres> for PgLTree { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend(1i8.to_le_bytes()); + write!(buf, "{self}")?; + + Ok(IsNull::No) + } +} + +impl<'r> Decode<'r, Postgres> for PgLTree { + fn decode(value: PgValueRef<'r>) -> Result { + match value.format() { + PgValueFormat::Binary => { + let bytes = value.as_bytes()?; + let version = i8::from_le_bytes([bytes[0]; 1]); + if version != 1 { + return Err(Box::new(PgLTreeParseError::InvalidLtreeVersion)); + } + Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?) 
+ } + PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?), + } + } +} diff --git a/patches/sqlx-postgres/src/types/mac_address.rs b/patches/sqlx-postgres/src/types/mac_address.rs new file mode 100644 index 000000000..23766e700 --- /dev/null +++ b/patches/sqlx-postgres/src/types/mac_address.rs @@ -0,0 +1,51 @@ +use mac_address::MacAddress; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +impl Type for MacAddress { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::MACADDR + } +} + +impl PgHasArrayType for MacAddress { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR_ARRAY + } +} + +impl Encode<'_, Postgres> for MacAddress { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend_from_slice(&self.bytes()); // write just the address + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + 6 + } +} + +impl Decode<'_, Postgres> for MacAddress { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + return Ok(value.as_str()?.parse()?); + } + }; + + if bytes.len() == 6 { + return Ok(MacAddress::new(bytes.try_into().unwrap())); + } + + Err("invalid data received when expecting an MACADDR".into()) + } +} diff --git a/patches/sqlx-postgres/src/types/mod.rs b/patches/sqlx-postgres/src/types/mod.rs new file mode 100644 index 000000000..846f1b731 --- /dev/null +++ b/patches/sqlx-postgres/src/types/mod.rs @@ -0,0 +1,275 @@ +//! Conversions between Rust and **Postgres** types. +//! +//! # Types +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bool` | BOOL | +//! | `i8` | "CHAR" | +//! | `i16` | SMALLINT, SMALLSERIAL, INT2 | +//! | `i32` | INT, SERIAL, INT4 | +//! | `i64` | BIGINT, BIGSERIAL, INT8 | +//! | `f32` | REAL, FLOAT4 | +//! | `f64` | DOUBLE PRECISION, FLOAT8 | +//! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME, CITEXT | +//! | `&[u8]`, `Vec` | BYTEA | +//! | `()` | VOID | +//! | [`PgInterval`] | INTERVAL | +//! | [`PgRange`](PgRange) | INT8RANGE, INT4RANGE, TSRANGE, TSTZRANGE, DATERANGE, NUMRANGE | +//! | [`PgMoney`] | MONEY | +//! | [`PgLTree`] | LTREE | +//! | [`PgLQuery`] | LQUERY | +//! | [`PgCiText`] | CITEXT1 | +//! | [`PgCube`] | CUBE | +//! | [`PgHstore`] | HSTORE | +//! +//! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., +//! but this wrapper type is available for edge cases, such as `CITEXT[]` which Postgres +//! does not consider to be compatible with `TEXT[]`. +//! +//! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) +//! Requires the `bigdecimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bigdecimal::BigDecimal` | NUMERIC | +//! +#![doc=include_str!("bigdecimal-range.md")] +//! +//! ### [`rust_decimal`](https://crates.io/crates/rust_decimal) +//! Requires the `rust_decimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `rust_decimal::Decimal` | NUMERIC | +//! 
+#![doc=include_str!("rust_decimal-range.md")]
+//!
+//! ### [`chrono`](https://crates.io/crates/chrono)
+//!
+//! Requires the `chrono` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `chrono::DateTime<Utc>`               | TIMESTAMPTZ                                          |
+//! | `chrono::DateTime<Local>`             | TIMESTAMPTZ                                          |
+//! | `chrono::NaiveDateTime`               | TIMESTAMP                                            |
+//! | `chrono::NaiveDate`                   | DATE                                                 |
+//! | `chrono::NaiveTime`                   | TIME                                                 |
+//! | [`PgTimeTz`]                          | TIMETZ                                               |
+//!
+//! ### [`time`](https://crates.io/crates/time)
+//!
+//! Requires the `time` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `time::PrimitiveDateTime`             | TIMESTAMP                                            |
+//! | `time::OffsetDateTime`                | TIMESTAMPTZ                                          |
+//! | `time::Date`                          | DATE                                                 |
+//! | `time::Time`                          | TIME                                                 |
+//! | [`PgTimeTz`]                          | TIMETZ                                               |
+//!
+//! ### [`uuid`](https://crates.io/crates/uuid)
+//!
+//! Requires the `uuid` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `uuid::Uuid`                          | UUID                                                 |
+//!
+//! ### [`ipnetwork`](https://crates.io/crates/ipnetwork)
+//!
+//! Requires the `ipnetwork` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `ipnetwork::IpNetwork`                | INET, CIDR                                           |
+//! | `std::net::IpAddr`                    | INET, CIDR                                           |
+//!
+//! Note that because `IpAddr` does not support network prefixes, it is an error to attempt to decode
+//! an `IpAddr` from an `INET` or `CIDR` value with a network prefix smaller than the address' full width:
+//! `/32` for IPv4 addresses and `/128` for IPv6 addresses.
+//!
+//! `IpNetwork` does not have this limitation.
+//!
+//! ### [`mac_address`](https://crates.io/crates/mac_address)
+//!
//! Requires the `mac_address` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `mac_address::MacAddress`             | MACADDR                                              |
+//!
+//! ### [`bit-vec`](https://crates.io/crates/bit-vec)
+//!
+//! Requires the `bit-vec` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `bit_vec::BitVec`                     | BIT, VARBIT                                          |
+//!
+//! ### [`json`](https://crates.io/crates/serde_json)
+//!
+//! Requires the `json` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | [`Json<T>`]                           | JSON, JSONB                                          |
+//! | `serde_json::Value`                   | JSON, JSONB                                          |
+//! | `&serde_json::value::RawValue`        | JSON, JSONB                                          |
+//!
+//! `Value` and `RawValue` from `serde_json` can be used for unstructured JSON data with
+//! Postgres.
+//!
+//! [`Json`](crate::types::Json) can be used for structured JSON data with Postgres.
+//!
+//! # [Composite types](https://www.postgresql.org/docs/current/rowtypes.html)
+//!
+//! User-defined composite types are supported through a derive for `Type`.
+//!
+//! ```text
+//! CREATE TYPE inventory_item AS (
+//!     name text,
+//!     supplier_id integer,
+//!     price numeric
+//! );
+//! ```
+//!
+//! ```rust,ignore
+//! #[derive(sqlx::Type)]
+//! #[sqlx(type_name = "inventory_item")]
+//! struct InventoryItem {
+//! 
name: String, +//! supplier_id: i32, +//! price: BigDecimal, +//! } +//! ``` +//! +//! Anonymous composite types are represented as tuples. Note that anonymous composites may only +//! be returned and not sent to Postgres (this is a limitation of postgres). +//! +//! # Arrays +//! +//! One-dimensional arrays are supported as `Vec` or `&[T]` where `T` implements `Type`. +//! +//! # [Enumerations](https://www.postgresql.org/docs/current/datatype-enum.html) +//! +//! User-defined enumerations are supported through a derive for `Type`. +//! +//! ```text +//! CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); +//! ``` +//! +//! ```rust,ignore +//! #[derive(sqlx::Type)] +//! #[sqlx(type_name = "mood", rename_all = "lowercase")] +//! enum Mood { Sad, Ok, Happy } +//! ``` +//! +//! Rust enumerations may also be defined to be represented as an integer using `repr`. +//! The following type expects a SQL type of `INTEGER` or `INT4` and will convert to/from the +//! Rust enumeration. +//! +//! ```rust,ignore +//! #[derive(sqlx::Type)] +//! #[repr(i32)] +//! enum Mood { Sad = 0, Ok = 1, Happy = 2 } +//! ``` +//! + +use crate::type_info::PgTypeKind; +use crate::{PgTypeInfo, Postgres}; + +pub(crate) use sqlx_core::types::{Json, Type}; + +mod array; +mod bool; +mod bytes; +mod citext; +mod float; +mod hstore; +mod int; +mod interval; +mod lquery; +mod ltree; +// Not behind a Cargo feature because we require JSON in the driver implementation. +mod json; +mod money; +mod oid; +mod range; +mod record; +mod str; +mod text; +mod tuple; +mod void; + +#[cfg(any(feature = "chrono", feature = "time"))] +mod time_tz; + +#[cfg(feature = "bigdecimal")] +mod bigdecimal; + +mod cube; + +#[cfg(any(feature = "bigdecimal", feature = "rust_decimal"))] +mod numeric; + +#[cfg(feature = "rust_decimal")] +mod rust_decimal; + +#[cfg(feature = "chrono")] +mod chrono; + +#[cfg(feature = "time")] +mod time; + +#[cfg(feature = "uuid")] +mod uuid; + +#[cfg(feature = "ipnetwork")] +mod ipnetwork; + +#[cfg(feature = "ipnetwork")] +mod ipaddr; + +#[cfg(feature = "mac_address")] +mod mac_address; + +#[cfg(feature = "bit-vec")] +mod bit_vec; + +pub use array::PgHasArrayType; +pub use citext::PgCiText; +pub use cube::PgCube; +pub use hstore::PgHstore; +pub use interval::PgInterval; +pub use lquery::PgLQuery; +pub use lquery::PgLQueryLevel; +pub use lquery::PgLQueryVariant; +pub use lquery::PgLQueryVariantFlag; +pub use ltree::PgLTree; +pub use ltree::PgLTreeLabel; +pub use ltree::PgLTreeParseError; +pub use money::PgMoney; +pub use oid::Oid; +pub use range::PgRange; + +#[cfg(any(feature = "chrono", feature = "time"))] +pub use time_tz::PgTimeTz; + +// used in derive(Type) for `struct` +// but the interface is not considered part of the public API +#[doc(hidden)] +pub use record::{PgRecordDecoder, PgRecordEncoder}; + +// Type::compatible impl appropriate for arrays +fn array_compatible + ?Sized>(ty: &PgTypeInfo) -> bool { + // we require the declared type to be an _array_ with an + // element type that is acceptable + if let PgTypeKind::Array(element) = &ty.kind() { + return E::compatible(element); + } + + false +} diff --git a/patches/sqlx-postgres/src/types/money.rs b/patches/sqlx-postgres/src/types/money.rs new file mode 100644 index 000000000..52fc68795 --- /dev/null +++ b/patches/sqlx-postgres/src/types/money.rs @@ -0,0 +1,365 @@ +use crate::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, + {PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}, +}; +use 
byteorder::{BigEndian, ByteOrder};
+use std::{
+    io,
+    ops::{Add, AddAssign, Sub, SubAssign},
+};
+
+/// The PostgreSQL [`MONEY`] type stores a currency amount with a fixed fractional
+/// precision. The fractional precision is determined by the database's
+/// `lc_monetary` setting.
+///
+/// Data is read and written as 64-bit signed integers, and conversion into a
+/// decimal should be done using the right precision.
+///
+/// Reading a `MONEY` value in text format is not supported and will cause an error.
+///
+/// ### `locale_frac_digits`
+/// This parameter corresponds to the number of digits after the decimal separator.
+///
+/// This value must match what Postgres is expecting for the locale set in the database
+/// or else the decimal value you see on the client side will not match the `money` value
+/// on the server side.
+///
+/// **For _most_ locales, this value is `2`.**
+///
+/// If you're not sure what locale your database is set to or how many decimal digits it specifies,
+/// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list
+/// (you can ignore the `.utf8` prefix):
+///
+///
+/// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute:
+///
+/// ```sh
+/// $ LC_MONETARY=<locale> locale -k frac_digits
+/// ```
+///
+/// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database
+/// server you should execute it there as available locales may differ between machines.
+///
+/// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes
+/// it's a sentinel value and defaults to 2:
+///
+///
+/// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
+pub struct PgMoney(
+    /// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most
+    /// of them), this will be the value in whole cents.
+    ///
+    /// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`),
+    /// this will be `12345`.
+    ///
+    /// If the currency of your locale does not have fractional units, e.g. Yen, then this will
+    /// just be the units of the currency.
+    ///
+    /// See the type-level docs for an explanation of `locale_frac_digits`.
+    pub i64,
+);
+
+impl PgMoney {
+    /// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`.
+    ///
+    /// See the type-level docs for an explanation of `locale_frac_digits`.
+    ///
+    /// [`BigDecimal`]: bigdecimal::BigDecimal
+    #[cfg(feature = "bigdecimal")]
+    pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal {
+        let digits = num_bigint::BigInt::from(self.0);
+
+        bigdecimal::BigDecimal::new(digits, locale_frac_digits)
+    }
+
+    /// Convert the money value into a [`Decimal`] using `locale_frac_digits`.
+    ///
+    /// See the type-level docs for an explanation of `locale_frac_digits`.
+    ///
+    /// [`Decimal`]: rust_decimal::Decimal
+    #[cfg(feature = "rust_decimal")]
+    pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal {
+        rust_decimal::Decimal::new(self.0, locale_frac_digits)
+    }
+
+    /// Convert a [`Decimal`] value into money using `locale_frac_digits`.
+    ///
+    /// See the type-level docs for an explanation of `locale_frac_digits`.
+    ///
+    /// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit.
+    /// If the value is larger than 63 bits it will be truncated.
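+    ///
+    /// A sketch of the expected conversion (illustrative, mirroring the unit tests below):
+    ///
+    /// ```ignore
+    /// // 123.45 with locale_frac_digits = 2 becomes 12345 raw units
+    /// let money = PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2);
+    /// assert_eq!(money, PgMoney(12345));
+    /// ```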
+ /// + /// [`Decimal`]: rust_decimal::Decimal + #[cfg(feature = "rust_decimal")] + pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self { + // this is all we need to convert to our expected locale's `frac_digits` + decimal.rescale(locale_frac_digits); + + /// a mask to bitwise-AND with an `i64` to zero the sign bit + const SIGN_MASK: i64 = i64::MAX; + + let is_negative = decimal.is_sign_negative(); + let serialized = decimal.serialize(); + + // interpret bytes `4..12` as an i64, ignoring the sign bit + // this is where truncation occurs + let value = i64::from_le_bytes( + *<&[u8; 8]>::try_from(&serialized[4..12]) + .expect("BUG: slice of serialized should be 8 bytes"), + ) & SIGN_MASK; // zero out the sign bit + + // negate if necessary + Self(if is_negative { -value } else { value }) + } + + /// Convert a [`BigDecimal`](bigdecimal::BigDecimal) value into money using the correct precision + /// defined in the PostgreSQL settings. The default precision is two. + #[cfg(feature = "bigdecimal")] + pub fn from_bigdecimal( + decimal: bigdecimal::BigDecimal, + locale_frac_digits: u32, + ) -> Result { + use bigdecimal::ToPrimitive; + + let multiplier = bigdecimal::BigDecimal::new( + num_bigint::BigInt::from(10i128.pow(locale_frac_digits)), + 0, + ); + + let cents = decimal * multiplier; + + let money = cents.to_i64().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "Provided BigDecimal could not convert to i64: overflow.", + ) + })?; + + Ok(Self(money)) + } +} + +impl Type for PgMoney { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MONEY + } +} + +impl PgHasArrayType for PgMoney { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::MONEY_ARRAY + } +} + +impl From for PgMoney +where + T: Into, +{ + fn from(num: T) -> Self { + Self(num.into()) + } +} + +impl Encode<'_, Postgres> for PgMoney { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend(&self.0.to_be_bytes()); + + Ok(IsNull::No) + } +} + +impl Decode<'_, Postgres> for PgMoney { + fn decode(value: PgValueRef<'_>) -> Result { + match value.format() { + PgValueFormat::Binary => { + let cents = BigEndian::read_i64(value.as_bytes()?); + + Ok(PgMoney(cents)) + } + PgValueFormat::Text => { + let error = io::Error::new( + io::ErrorKind::InvalidData, + "Reading a `MONEY` value in text format is not supported.", + ); + + Err(Box::new(error)) + } + } + } +} + +impl Add for PgMoney { + type Output = PgMoney; + + /// Adds two monetary values. + /// + /// # Panics + /// Panics if overflowing the `i64::MAX`. + fn add(self, rhs: PgMoney) -> Self::Output { + self.0 + .checked_add(rhs.0) + .map(PgMoney) + .expect("overflow adding money amounts") + } +} + +impl AddAssign for PgMoney { + /// An assigning add for two monetary values. + /// + /// # Panics + /// Panics if overflowing the `i64::MAX`. + fn add_assign(&mut self, rhs: PgMoney) { + self.0 = self + .0 + .checked_add(rhs.0) + .expect("overflow adding money amounts") + } +} + +impl Sub for PgMoney { + type Output = PgMoney; + + /// Subtracts two monetary values. + /// + /// # Panics + /// Panics if underflowing the `i64::MIN`. + fn sub(self, rhs: PgMoney) -> Self::Output { + self.0 + .checked_sub(rhs.0) + .map(PgMoney) + .expect("overflow subtracting money amounts") + } +} + +impl SubAssign for PgMoney { + /// An assigning subtract for two monetary values. + /// + /// # Panics + /// Panics if underflowing the `i64::MIN`. 
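+    ///
+    /// Example (illustrative, mirroring the unit tests below):
+    ///
+    /// ```ignore
+    /// let mut money = PgMoney(1);
+    /// money -= PgMoney(2);
+    /// assert_eq!(money, PgMoney(-1));
+    /// ```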
+ fn sub_assign(&mut self, rhs: PgMoney) { + self.0 = self + .0 + .checked_sub(rhs.0) + .expect("overflow subtracting money amounts") + } +} + +#[cfg(test)] +mod tests { + use super::PgMoney; + + #[test] + fn adding_works() { + assert_eq!(PgMoney(3), PgMoney(1) + PgMoney(2)) + } + + #[test] + fn add_assign_works() { + let mut money = PgMoney(1); + money += PgMoney(2); + + assert_eq!(PgMoney(3), money); + } + + #[test] + fn subtracting_works() { + assert_eq!(PgMoney(4), PgMoney(5) - PgMoney(1)) + } + + #[test] + fn sub_assign_works() { + let mut money = PgMoney(1); + money -= PgMoney(2); + + assert_eq!(PgMoney(-1), money); + } + + #[test] + fn default_value() { + let money = PgMoney::default(); + + assert_eq!(money, PgMoney(0)); + } + + #[test] + #[should_panic] + fn add_overflow_panics() { + let _ = PgMoney(i64::MAX) + PgMoney(1); + } + + #[test] + #[should_panic] + fn add_assign_overflow_panics() { + let mut money = PgMoney(i64::MAX); + money += PgMoney(1); + } + + #[test] + #[should_panic] + fn sub_overflow_panics() { + let _ = PgMoney(i64::MIN) - PgMoney(1); + } + + #[test] + #[should_panic] + fn sub_assign_overflow_panics() { + let mut money = PgMoney(i64::MIN); + money -= PgMoney(1); + } + + #[test] + #[cfg(feature = "bigdecimal")] + fn conversion_to_bigdecimal_works() { + let money = PgMoney(12345); + + assert_eq!( + bigdecimal::BigDecimal::new(num_bigint::BigInt::from(12345), 2), + money.to_bigdecimal(2) + ); + } + + #[test] + #[cfg(feature = "rust_decimal")] + fn conversion_to_decimal_works() { + assert_eq!( + rust_decimal::Decimal::new(12345, 2), + PgMoney(12345).to_decimal(2) + ); + } + + #[test] + #[cfg(feature = "rust_decimal")] + fn conversion_from_decimal_works() { + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2) + ); + + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12300), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2) + ); + } + + #[test] + #[cfg(feature = "bigdecimal")] + fn conversion_from_bigdecimal_works() { + let dec = bigdecimal::BigDecimal::new(num_bigint::BigInt::from(12345), 2); + + assert_eq!(PgMoney(12345), PgMoney::from_bigdecimal(dec, 2).unwrap()); + } +} diff --git a/patches/sqlx-postgres/src/types/numeric.rs b/patches/sqlx-postgres/src/types/numeric.rs new file mode 100644 index 000000000..67713d769 --- /dev/null +++ b/patches/sqlx-postgres/src/types/numeric.rs @@ -0,0 +1,172 @@ +use sqlx_core::bytes::Buf; +use std::num::Saturating; + +use crate::error::BoxDynError; +use crate::PgArgumentBuffer; + +/// Represents a `NUMERIC` value in the **Postgres** wire protocol. +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum PgNumeric { + /// Equivalent to the `'NaN'` value in Postgres. The result of, e.g. `1 / 0`. + NotANumber, + + /// A populated `NUMERIC` value. + /// + /// A description of these fields can be found here (although the type being described is the + /// version for in-memory calculations, the field names are the same): + /// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L224-L269 + Number { + /// The sign of the value: positive (also set for 0 and -0), or negative. + sign: PgNumericSign, + + /// The digits of the number in base-10000 with the most significant digit first + /// (big-endian). 
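+        /// For example (illustrative): the decimal value `12345.67` would be represented as
+        /// `digits = [1, 2345, 6700]` with `weight = 1` and `scale = 2`.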
+ /// + /// The length of this vector must not overflow `i16` for the binary protocol. + /// + /// *Note*: the `Encode` implementation will panic if any digit is `>= 10000`. + digits: Vec, + + /// The scaling factor of the number, such that the value will be interpreted as + /// + /// ```text + /// digits[0] * 10,000 ^ weight + /// + digits[1] * 10,000 ^ (weight - 1) + /// ... + /// + digits[N] * 10,000 ^ (weight - N) where N = digits.len() - 1 + /// ``` + /// May be negative. + weight: i16, + + /// How many _decimal_ (base-10) digits following the decimal point to consider in + /// arithmetic regardless of how many actually follow the decimal point as determined by + /// `weight`--the comment in the Postgres code linked above recommends using this only for + /// ignoring unnecessary trailing zeroes (as trimming nonzero digits means reducing the + /// precision of the value). + /// + /// Must be `>= 0`. + scale: i16, + }, +} + +// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170 +const SIGN_POS: u16 = 0x0000; +const SIGN_NEG: u16 = 0x4000; +const SIGN_NAN: u16 = 0xC000; // overflows i16 (C equivalent truncates from integer literal) + +/// Possible sign values for [PgNumeric]. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u16)] +pub(crate) enum PgNumericSign { + Positive = SIGN_POS, + Negative = SIGN_NEG, +} + +impl PgNumericSign { + fn try_from_u16(val: u16) -> Result { + match val { + SIGN_POS => Ok(PgNumericSign::Positive), + SIGN_NEG => Ok(PgNumericSign::Negative), + + SIGN_NAN => unreachable!("sign value for NaN passed to PgNumericSign"), + + _ => Err(format!("invalid value for PgNumericSign: {val:#04X}").into()), + } + } +} + +impl PgNumeric { + /// Equivalent value of `0::numeric`. 
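+    /// (A positive `Number` with no digits; `weight` and `scale` are both `0`.)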
+ pub const ZERO: Self = PgNumeric::Number { + sign: PgNumericSign::Positive, + digits: vec![], + weight: 0, + scale: 0, + }; + + pub(crate) fn is_valid_digit(digit: i16) -> bool { + (0..10_000).contains(&digit) + } + + pub(crate) fn size_hint(decimal_digits: u64) -> usize { + let mut size_hint = Saturating(decimal_digits); + + // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits + // and since this is just a hint we just always round up + size_hint /= 4; + size_hint += 1; + + // Times two bytes for each base-10000 digit + size_hint *= 2; + + // Plus `weight` and `scale` + size_hint += 8; + + usize::try_from(size_hint.0).unwrap_or(usize::MAX) + } + + pub(crate) fn decode(mut buf: &[u8]) -> Result { + // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 + let num_digits = buf.get_u16(); + let weight = buf.get_i16(); + let sign = buf.get_u16(); + let scale = buf.get_i16(); + + if sign == SIGN_NAN { + Ok(PgNumeric::NotANumber) + } else { + let digits: Vec<_> = (0..num_digits).map(|_| buf.get_i16()).collect::<_>(); + + Ok(PgNumeric::Number { + sign: PgNumericSign::try_from_u16(sign)?, + scale, + weight, + digits, + }) + } + } + + /// ### Errors + /// + /// * If `digits.len()` overflows `i16` + /// * If any element in `digits` is greater than or equal to 10000 + pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) -> Result<(), String> { + match *self { + PgNumeric::Number { + ref digits, + sign, + scale, + weight, + } => { + let digits_len = i16::try_from(digits.len()).map_err(|_| { + format!( + "PgNumeric digits.len() ({}) should not overflow i16", + digits.len() + ) + })?; + + buf.extend(&digits_len.to_be_bytes()); + buf.extend(&weight.to_be_bytes()); + buf.extend(&(sign as i16).to_be_bytes()); + buf.extend(&scale.to_be_bytes()); + + for (i, &digit) in digits.iter().enumerate() { + if !Self::is_valid_digit(digit) { + return Err(format!("{i}th PgNumeric digit out of range: {digit}")); + } + + buf.extend(&digit.to_be_bytes()); + } + } + + PgNumeric::NotANumber => { + buf.extend(&0_i16.to_be_bytes()); + buf.extend(&0_i16.to_be_bytes()); + buf.extend(&SIGN_NAN.to_be_bytes()); + buf.extend(&0_i16.to_be_bytes()); + } + } + + Ok(()) + } +} diff --git a/patches/sqlx-postgres/src/types/oid.rs b/patches/sqlx-postgres/src/types/oid.rs new file mode 100644 index 000000000..04c5ef837 --- /dev/null +++ b/patches/sqlx-postgres/src/types/oid.rs @@ -0,0 +1,65 @@ +use byteorder::{BigEndian, ByteOrder}; +use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +/// The PostgreSQL [`OID`] type stores an object identifier, +/// used internally by PostgreSQL as primary keys for various system tables. 
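+///
+/// Illustrative round-trip (assumed): `Oid(1234)` is sent as the four big-endian
+/// bytes `[0, 0, 4, 210]` and decodes back to `Oid(1234)`.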
+/// +/// [`OID`]: https://www.postgresql.org/docs/current/datatype-oid.html +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, Default)] +pub struct Oid( + /// The raw unsigned integer value sent over the wire + pub u32, +); + +impl Type for Oid { + fn type_info() -> PgTypeInfo { + PgTypeInfo::OID + } +} + +impl PgHasArrayType for Oid { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::OID_ARRAY + } +} + +impl Encode<'_, Postgres> for Oid { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend(&self.0.to_be_bytes()); + + Ok(IsNull::No) + } +} + +impl Decode<'_, Postgres> for Oid { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(Self(match value.format() { + PgValueFormat::Binary => BigEndian::read_u32(value.as_bytes()?), + PgValueFormat::Text => value.as_str()?.parse()?, + })) + } +} + +impl Serialize for Oid { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Oid { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + u32::deserialize(deserializer).map(Self) + } +} diff --git a/patches/sqlx-postgres/src/types/range.rs b/patches/sqlx-postgres/src/types/range.rs new file mode 100644 index 000000000..5e1346d86 --- /dev/null +++ b/patches/sqlx-postgres/src/types/range.rs @@ -0,0 +1,522 @@ +use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}; + +use bitflags::bitflags; +use sqlx_core::bytes::Buf; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::type_info::PgTypeKind; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44 +bitflags! 
{ + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + struct RangeFlags: u8 { + const EMPTY = 0x01; + const LB_INC = 0x02; + const UB_INC = 0x04; + const LB_INF = 0x08; + const UB_INF = 0x10; + const LB_NULL = 0x20; // not used + const UB_NULL = 0x40; // not used + const CONTAIN_EMPTY = 0x80; // internal + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct PgRange { + pub start: Bound, + pub end: Bound, +} + +impl From<[Bound; 2]> for PgRange { + fn from(v: [Bound; 2]) -> Self { + let [start, end] = v; + Self { start, end } + } +} + +impl From<(Bound, Bound)> for PgRange { + fn from(v: (Bound, Bound)) -> Self { + Self { + start: v.0, + end: v.1, + } + } +} + +impl From> for PgRange { + fn from(v: Range) -> Self { + Self { + start: Bound::Included(v.start), + end: Bound::Excluded(v.end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeFrom) -> Self { + Self { + start: Bound::Included(v.start), + end: Bound::Unbounded, + } + } +} + +impl From> for PgRange { + fn from(v: RangeInclusive) -> Self { + let (start, end) = v.into_inner(); + Self { + start: Bound::Included(start), + end: Bound::Included(end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeTo) -> Self { + Self { + start: Bound::Unbounded, + end: Bound::Excluded(v.end), + } + } +} + +impl From> for PgRange { + fn from(v: RangeToInclusive) -> Self { + Self { + start: Bound::Unbounded, + end: Bound::Included(v.end), + } + } +} + +impl RangeBounds for PgRange { + fn start_bound(&self) -> Bound<&T> { + match self.start { + Bound::Included(ref start) => Bound::Included(start), + Bound::Excluded(ref start) => Bound::Excluded(start), + Bound::Unbounded => Bound::Unbounded, + } + } + + fn end_bound(&self) -> Bound<&T> { + match self.end { + Bound::Included(ref end) => Bound::Included(end), + Bound::Excluded(ref end) => Bound::Excluded(end), + Bound::Unbounded => Bound::Unbounded, + } + } +} + +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT4_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INT8_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "bigdecimal")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "rust_decimal")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "chrono")] +impl Type for PgRange> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::>(ty) + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> 
bool { + range_compatible::(ty) + } +} + +#[cfg(feature = "time")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INT4_RANGE_ARRAY + } +} + +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INT8_RANGE_ARRAY + } +} + +#[cfg(feature = "bigdecimal")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + +#[cfg(feature = "rust_decimal")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "chrono")] +impl PgHasArrayType for PgRange> { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::DATE_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TS_RANGE_ARRAY + } +} + +#[cfg(feature = "time")] +impl PgHasArrayType for PgRange { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TSTZ_RANGE_ARRAY + } +} + +impl<'q, T> Encode<'q, Postgres> for PgRange +where + T: Encode<'q, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245 + + let mut flags = RangeFlags::empty(); + + flags |= match self.start { + Bound::Included(_) => RangeFlags::LB_INC, + Bound::Unbounded => RangeFlags::LB_INF, + Bound::Excluded(_) => RangeFlags::empty(), + }; + + flags |= match self.end { + Bound::Included(_) => RangeFlags::UB_INC, + Bound::Unbounded => RangeFlags::UB_INF, + Bound::Excluded(_) => RangeFlags::empty(), + }; + + buf.push(flags.bits()); + + if let Bound::Included(v) | Bound::Excluded(v) = &self.start { + buf.encode(v)?; + } + + if let Bound::Included(v) | Bound::Excluded(v) = &self.end { + buf.encode(v)?; + } + + // ranges are themselves never null + Ok(IsNull::No) + } +} + +impl<'r, T> Decode<'r, Postgres> for PgRange +where + T: Type + for<'a> Decode<'a, Postgres>, +{ + fn decode(value: PgValueRef<'r>) -> Result { + match value.format { + PgValueFormat::Binary => { + let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() { + element + } else { + return Err(format!("unexpected non-range type {}", value.type_info).into()); + }; + + let mut buf = value.as_bytes()?; + + let mut start = Bound::Unbounded; + let mut end = Bound::Unbounded; + + let flags = RangeFlags::from_bits_truncate(buf.get_u8()); + + if flags.contains(RangeFlags::EMPTY) { + return Ok(PgRange { start, end }); + } + + if !flags.contains(RangeFlags::LB_INF) { + let value = + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?; + + start = if flags.contains(RangeFlags::LB_INC) { + Bound::Included(value) + } else { + Bound::Excluded(value) + }; + } + + if !flags.contains(RangeFlags::UB_INF) { + let value = + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?; + + 
end = if flags.contains(RangeFlags::UB_INC) { + Bound::Included(value) + } else { + Bound::Excluded(value) + }; + } + + Ok(PgRange { start, end }) + } + + PgValueFormat::Text => { + // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046 + + let mut start = None; + let mut end = None; + + let s = value.as_str()?; + + // remember the bounds + let sb = s.as_bytes(); + let lower = sb[0] as char; + let upper = sb[sb.len() - 1] as char; + + // trim the wrapping braces/brackets + let s = &s[1..(s.len() - 1)]; + + let mut chars = s.chars(); + + let mut element = String::new(); + let mut done = false; + let mut quoted = false; + let mut in_quotes = false; + let mut in_escape = false; + let mut prev_ch = '\0'; + let mut count = 0; + + while !done { + element.clear(); + + loop { + match chars.next() { + Some(ch) => { + match ch { + _ if in_escape => { + element.push(ch); + in_escape = false; + } + + '"' if in_quotes => { + in_quotes = false; + } + + '"' => { + in_quotes = true; + quoted = true; + + if prev_ch == '"' { + element.push('"') + } + } + + '\\' if !in_escape => { + in_escape = true; + } + + ',' if !in_quotes => break, + + _ => { + element.push(ch); + } + } + prev_ch = ch; + } + + None => { + done = true; + break; + } + } + } + + count += 1; + if !element.is_empty() || quoted { + let value = Some(T::decode(PgValueRef { + type_info: T::type_info(), + format: PgValueFormat::Text, + value: Some(element.as_bytes()), + row: None, + })?); + + if count == 1 { + start = value; + } else if count == 2 { + end = value; + } else { + return Err("more than 2 elements found in a range".into()); + } + } + } + + let start = parse_bound(lower, start)?; + let end = parse_bound(upper, end)?; + + Ok(PgRange { start, end }) + } + } + } +} + +fn parse_bound(ch: char, value: Option) -> Result, BoxDynError> { + Ok(if let Some(value) = value { + match ch { + '(' | ')' => Bound::Excluded(value), + '[' | ']' => Bound::Included(value), + + _ => { + return Err(format!( + "expected `(`, ')', '[', or `]` but found `{ch}` for range literal" + ) + .into()); + } + } + } else { + Bound::Unbounded + }) +} + +impl Display for PgRange +where + T: Display, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match &self.start { + Bound::Unbounded => f.write_str("(,")?, + Bound::Excluded(v) => write!(f, "({v},")?, + Bound::Included(v) => write!(f, "[{v},")?, + } + + match &self.end { + Bound::Unbounded => f.write_str(")")?, + Bound::Excluded(v) => write!(f, "{v})")?, + Bound::Included(v) => write!(f, "{v}]")?, + } + + Ok(()) + } +} + +fn range_compatible>(ty: &PgTypeInfo) -> bool { + // we require the declared type to be a _range_ with an + // element type that is acceptable + if let PgTypeKind::Range(element) = &ty.kind() { + return E::compatible(element); + } + + false +} diff --git a/patches/sqlx-postgres/src/types/record.rs b/patches/sqlx-postgres/src/types/record.rs new file mode 100644 index 000000000..c4eb63936 --- /dev/null +++ b/patches/sqlx-postgres/src/types/record.rs @@ -0,0 +1,205 @@ +use sqlx_core::bytes::Buf; + +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::{mismatched_types, BoxDynError}; +use crate::type_info::TypeInfo; +use crate::type_info::{PgType, PgTypeKind}; +use crate::types::Oid; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +#[doc(hidden)] +pub struct PgRecordEncoder<'a> { + buf: &'a mut PgArgumentBuffer, + off: usize, + num: u32, +} + 
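+// A sketch of the wire layout this encoder produces, inferred from the code
+// below and from Postgres' composite-type binary output (`record_send` in
+// src/backend/utils/adt/rowtypes.c); not an authoritative spec:
+//
+//     u32          field count -- reserved by `new`, backfilled by `finish`
+//     per field:
+//         u32      OID of the field's type (left as a hole to be patched at
+//                  execution time when the type is declared only by name)
+//         i32      byte length of the encoded value, or -1 for NULL
+//         [u8]     the encoded value itself
+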
+impl<'a> PgRecordEncoder<'a> { + #[doc(hidden)] + pub fn new(buf: &'a mut PgArgumentBuffer) -> Self { + let off = buf.len(); + + // reserve space for a field count + buf.extend(&(0_u32).to_be_bytes()); + + Self { buf, off, num: 0 } + } + + #[doc(hidden)] + pub fn finish(&mut self) { + // fill in the record length + self.buf[self.off..(self.off + 4)].copy_from_slice(&self.num.to_be_bytes()); + } + + #[doc(hidden)] + pub fn encode<'q, T>(&mut self, value: T) -> Result<&mut Self, BoxDynError> + where + 'a: 'q, + T: Encode<'q, Postgres> + Type, + { + let ty = value.produces().unwrap_or_else(T::type_info); + + if let PgType::DeclareWithName(name) = ty.0 { + // push a hole for this type ID + // to be filled in on query execution + self.buf.patch_type_by_name(&name); + } else { + // write type id + self.buf.extend(&ty.0.oid().0.to_be_bytes()); + } + + self.buf.encode(value)?; + self.num += 1; + + Ok(self) + } +} + +#[doc(hidden)] +pub struct PgRecordDecoder<'r> { + buf: &'r [u8], + typ: PgTypeInfo, + fmt: PgValueFormat, + ind: usize, +} + +impl<'r> PgRecordDecoder<'r> { + #[doc(hidden)] + pub fn new(value: PgValueRef<'r>) -> Result { + let fmt = value.format(); + let mut buf = value.as_bytes()?; + let typ = value.type_info; + + match fmt { + PgValueFormat::Binary => { + let _len = buf.get_u32(); + } + + PgValueFormat::Text => { + // remove the enclosing `(` .. `)` + buf = &buf[1..(buf.len() - 1)]; + } + } + + Ok(Self { + buf, + fmt, + typ, + ind: 0, + }) + } + + #[doc(hidden)] + pub fn try_decode(&mut self) -> Result + where + T: for<'a> Decode<'a, Postgres> + Type, + { + if self.buf.is_empty() { + return Err(format!("no field `{0}` found on record", self.ind).into()); + } + + match self.fmt { + PgValueFormat::Binary => { + let element_type_oid = Oid(self.buf.get_u32()); + let element_type_opt = match self.typ.0.kind() { + PgTypeKind::Simple if self.typ.0 == PgType::Record => { + PgTypeInfo::try_from_oid(element_type_oid) + } + + PgTypeKind::Composite(fields) => { + let ty = fields[self.ind].1.clone(); + if ty.0.oid() != element_type_oid { + return Err("unexpected mismatch of composite type information".into()); + } + + Some(ty) + } + + _ => { + return Err( + "unexpected non-composite type being decoded as a composite type" + .into(), + ); + } + }; + + if let Some(ty) = &element_type_opt { + if !ty.is_null() && !T::compatible(ty) { + return Err(mismatched_types::(ty)); + } + } + + let element_type = + element_type_opt + .ok_or_else(|| BoxDynError::from(format!("custom types in records are not fully supported yet: failed to retrieve type info for field {} with type oid {}", self.ind, element_type_oid.0)))?; + + self.ind += 1; + + T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)?) 
+            }
+
+            PgValueFormat::Text => {
+                let mut element = String::new();
+                let mut quoted = false;
+                let mut in_quotes = false;
+                let mut in_escape = false;
+                let mut prev_ch = '\0';
+
+                while !self.buf.is_empty() {
+                    let ch = self.buf.get_u8() as char;
+                    match ch {
+                        _ if in_escape => {
+                            element.push(ch);
+                            in_escape = false;
+                        }
+
+                        '"' if in_quotes => {
+                            in_quotes = false;
+                        }
+
+                        '"' => {
+                            in_quotes = true;
+                            quoted = true;
+
+                            if prev_ch == '"' {
+                                element.push('"')
+                            }
+                        }
+
+                        '\\' if !in_escape => {
+                            in_escape = true;
+                        }
+
+                        ',' if !in_quotes => break,
+
+                        _ => {
+                            element.push(ch);
+                        }
+                    }
+                    prev_ch = ch;
+                }
+
+                let buf = if element.is_empty() && !quoted {
+                    // completely empty input means NULL
+                    None
+                } else {
+                    Some(element.as_bytes())
+                };
+
+                // NOTE: we do not call [`accepts`] or give the user a chance to override this,
+                // as TEXT sequences are not strongly typed
+
+                T::decode(PgValueRef {
+                    // NOTE: We pass `0` as the type ID because we don't have a reasonable value
+                    // we could use.
+                    type_info: PgTypeInfo::with_oid(Oid(0)),
+                    format: self.fmt,
+                    value: buf,
+                    row: None,
+                })
+            }
+        }
+    }
+}
diff --git a/patches/sqlx-postgres/src/types/rust_decimal-range.md b/patches/sqlx-postgres/src/types/rust_decimal-range.md
new file mode 100644
index 000000000..f986d616f
--- /dev/null
+++ b/patches/sqlx-postgres/src/types/rust_decimal-range.md
@@ -0,0 +1,10 @@
+#### Note: `rust_decimal::Decimal` Has a Smaller Range than `NUMERIC`
+`NUMERIC` can have up to 131,072 digits before the decimal point, and 16,384 digits after it.
+See [Section 8.1, Numeric Types] of the Postgres manual for details.
+
+However, `rust_decimal::Decimal` is limited to a maximum absolute magnitude of 2<sup>96</sup> - 1,
+a number with 29 decimal digits, and a minimum absolute magnitude of 10<sup>-28</sup>, a number with, unsurprisingly,
+28 decimal digits.
+
+Thus, in contrast with `BigDecimal`, `NUMERIC` can actually represent every possible value of `rust_decimal::Decimal`,
+but not the other way around. This means that encoding should never fail, but decoding can.
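The asymmetry described in that note is worth seeing end to end: every `Decimal` encodes into `NUMERIC`, while a `NUMERIC` value outside `Decimal`'s range surfaces as a decode error rather than a silently truncated value. A minimal sketch of both directions, assuming an async context, a connected `sqlx::PgPool`, and the `rust_decimal` feature enabled; the function name and queries are illustrative, not part of the patch:

```rust
use rust_decimal::Decimal;
use sqlx::PgPool;

async fn numeric_range_demo(pool: &PgPool) -> sqlx::Result<()> {
    // Encoding cannot fail: even Decimal::MAX (2^96 - 1) fits in NUMERIC,
    // so the value round-trips exactly.
    let back: Decimal = sqlx::query_scalar("SELECT $1::numeric")
        .bind(Decimal::MAX)
        .fetch_one(pool)
        .await?;
    assert_eq!(back, Decimal::MAX);

    // Decoding can fail: 10^40 is a perfectly valid NUMERIC, but it exceeds
    // Decimal::MAX, so the PgNumeric -> Decimal conversion reports an error.
    let too_big: sqlx::Result<Decimal> = sqlx::query_scalar("SELECT '1e40'::numeric")
        .fetch_one(pool)
        .await;
    assert!(too_big.is_err());

    Ok(())
}
```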
diff --git a/patches/sqlx-postgres/src/types/rust_decimal.rs b/patches/sqlx-postgres/src/types/rust_decimal.rs new file mode 100644 index 000000000..8321e8281 --- /dev/null +++ b/patches/sqlx-postgres/src/types/rust_decimal.rs @@ -0,0 +1,493 @@ +use rust_decimal::Decimal; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::numeric::{PgNumeric, PgNumericSign}; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +use rust_decimal::MathematicalOps; + +impl Type for Decimal { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC + } +} + +impl PgHasArrayType for Decimal { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC_ARRAY + } +} + +impl TryFrom for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: PgNumeric) -> Result { + Decimal::try_from(&numeric) + } +} + +impl TryFrom<&'_ PgNumeric> for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: &'_ PgNumeric) -> Result { + let (digits, sign, mut weight, scale) = match *numeric { + PgNumeric::Number { + ref digits, + sign, + weight, + scale, + } => (digits, sign, weight, scale), + + PgNumeric::NotANumber => { + return Err("Decimal does not support NaN values".into()); + } + }; + + if digits.is_empty() { + // Postgres returns an empty digit array for 0 + return Ok(Decimal::ZERO); + } + + let scale = u32::try_from(scale) + .map_err(|_| format!("invalid scale value for Pg NUMERIC: {scale}"))?; + + let mut value = Decimal::ZERO; + + // Sum over `digits`, multiply each by its weight and add it to `value`. + for &digit in digits { + let mul = Decimal::from(10_000i16) + .checked_powi(weight as i64) + .ok_or("value not representable as rust_decimal::Decimal")?; + + let part = Decimal::from(digit) * mul; + + value = value + .checked_add(part) + .ok_or("value not representable as rust_decimal::Decimal")?; + + weight = weight.checked_sub(1).ok_or("weight underflowed")?; + } + + match sign { + PgNumericSign::Positive => value.set_sign_positive(true), + PgNumericSign::Negative => value.set_sign_negative(true), + } + + value.rescale(scale); + + Ok(value) + } +} + +impl From for PgNumeric { + fn from(value: Decimal) -> Self { + PgNumeric::from(&value) + } +} + +// This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`. +impl From<&'_ Decimal> for PgNumeric { + // Impl has been manually validated. + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] + fn from(decimal: &Decimal) -> Self { + if Decimal::is_zero(decimal) { + return PgNumeric::ZERO; + } + + assert!( + (0u32..=28).contains(&decimal.scale()), + "decimal scale out of range {:?}", + decimal.unpack(), + ); + + // Cannot overflow: always in the range [0, 28] + let scale = decimal.scale() as u16; + + let mut mantissa = decimal.mantissa().unsigned_abs(); + + // If our scale is not a multiple of 4, we need to go to the next multiple. + let groups_diff = scale % 4; + if groups_diff > 0 { + let remainder = 4 - groups_diff as u32; + let power = 10u32.pow(remainder) as u128; + + // Impossible to overflow; 0 <= mantissa <= 2^96, + // and we're multiplying by at most 1,000 (giving us a result < 2^106) + mantissa *= power; + } + + // Array to store max mantissa of Decimal in Postgres decimal format. + let mut digits = Vec::with_capacity(8); + + // Convert to base-10000. 
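+        // (For example, mantissa 12_345_678 pushes 5678 and then 1234; the
+        // `reverse` below restores the most-significant-first order Postgres
+        // expects, giving digits == [1234, 5678].)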
+ while mantissa != 0 { + // Cannot overflow or wrap because of the modulus + digits.push((mantissa % 10_000) as i16); + mantissa /= 10_000; + } + + // We started with the low digits first, but they should actually be at the end. + digits.reverse(); + + // Cannot overflow: strictly smaller than `scale`. + let digits_after_decimal = scale.div_ceil(4) as i16; + + // `mantissa` contains at most 29 decimal digits (log10(2^96)), + // split into at most 8 4-digit segments. + assert!( + digits.len() <= 8, + "digits.len() out of range: {}; unpacked: {:?}", + digits.len(), + decimal.unpack() + ); + + // Cannot overflow; at most 8 + let num_digits = digits.len() as i16; + + // Find how many 4-digit segments should go before the decimal point. + // `weight = 0` puts just `digit[0]` before the decimal point, and the rest after. + let weight = num_digits - digits_after_decimal - 1; + + // Remove non-significant zeroes. + while let Some(&0) = digits.last() { + digits.pop(); + } + + PgNumeric::Number { + sign: match decimal.is_sign_negative() { + false => PgNumericSign::Positive, + true => PgNumericSign::Negative, + }, + // Cannot overflow; between 0 and 28 + scale: scale as i16, + weight, + digits, + } + } +} + +impl Encode<'_, Postgres> for Decimal { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + PgNumeric::from(self).encode(buf)?; + + Ok(IsNull::No) + } +} + +#[doc=include_str!("rust_decimal-range.md")] +impl Decode<'_, Postgres> for Decimal { + fn decode(value: PgValueRef<'_>) -> Result { + match value.format() { + PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), + PgValueFormat::Text => Ok(value.as_str()?.parse::()?), + } + } +} + +#[cfg(test)] +mod tests { + use super::{Decimal, PgNumeric, PgNumericSign}; + use std::convert::TryFrom; + + #[test] + fn zero() { + let zero: Decimal = "0".parse().unwrap(); + + assert_eq!(PgNumeric::from(&zero), PgNumeric::ZERO,); + + assert_eq!(Decimal::try_from(&PgNumeric::ZERO).unwrap(), Decimal::ZERO); + } + + #[test] + fn one() { + let one: Decimal = "1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1] + } + ); + } + + #[test] + fn ten() { + let ten: Decimal = "10".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![10] + } + ); + } + + #[test] + fn one_hundred() { + let one_hundred: Decimal = "100".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_hundred).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![100] + } + ); + } + + #[test] + fn ten_thousand() { + // Decimal doesn't normalize here + let ten_thousand: Decimal = "10000".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten_thousand).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1] + } + ); + } + + #[test] + fn two_digits() { + let two_digits: Decimal = "12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&two_digits).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn one_tenth() { + let one_tenth: Decimal = "0.1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_tenth).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 1, + weight: -1, + digits: vec![1000] + } + ); + } + + 
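+    // Tracing `one_tenth` through the conversion above: "0.1" has mantissa 1
+    // and scale 1; the scale is padded to the next multiple of 4, scaling the
+    // mantissa to 1000 and producing the single base-10000 digit [1000].
+    // With num_digits = 1 and scale.div_ceil(4) = 1 group after the decimal
+    // point, weight = 1 - 1 - 1 = -1, matching the assertion above.
+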
#[test] + fn decimal_1() { + let decimal: Decimal = "1.2345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 4, + weight: 0, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn decimal_2() { + let decimal: Decimal = "0.12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![1234, 5000] + } + ); + } + + #[test] + fn decimal_3() { + let decimal: Decimal = "0.01234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![0123, 4000] + } + ); + } + + #[test] + fn decimal_4() { + let decimal: Decimal = "12345.67890".parse().unwrap(); + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: 1, + digits: vec![1, 2345, 6789], + }; + assert_eq!(PgNumeric::try_from(&decimal).unwrap(), expected_numeric); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, decimal); + assert_eq!(actual_decimal.mantissa(), 1234567890); + assert_eq!(actual_decimal.scale(), 5); + } + + #[test] + fn one_digit_decimal() { + let one_digit_decimal: Decimal = "0.00001234".parse().unwrap(); + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 8, + weight: -2, + digits: vec![1234], + }; + assert_eq!( + PgNumeric::try_from(&one_digit_decimal).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, one_digit_decimal); + assert_eq!(actual_decimal.mantissa(), 1234); + assert_eq!(actual_decimal.scale(), 8); + } + + #[test] + fn max_value() { + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 7, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&Decimal::MAX).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, Decimal::MAX); + // Value split by 10,000's to match the expected digits[] + assert_eq!( + actual_decimal.mantissa(), + 7_9228_1625_1426_4337_5935_4395_0335 + ); + assert_eq!(actual_decimal.scale(), 0); + } + + #[test] + fn max_value_max_scale() { + let mut max_value_max_scale = Decimal::MAX; + max_value_max_scale.set_scale(28).unwrap(); + + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 28, + weight: 0, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&max_value_max_scale).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, max_value_max_scale); + assert_eq!( + actual_decimal.mantissa(), + 79_228_162_514_264_337_593_543_950_335 + ); + assert_eq!(actual_decimal.scale(), 28); + } + + #[test] + fn issue_423_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let four_digit: Decimal = "1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_negative_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let 
negative_four_digit: Decimal = "-1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let eight_digit: Decimal = "12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } + + #[test] + fn issue_423_negative_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_eight_digit: Decimal = "-12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } + + #[test] + fn issue_2247_trailing_zeros() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/2247 + let one_hundred: Decimal = "100.00".parse().unwrap(); + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 2, + weight: 0, + digits: vec![100], + }; + assert_eq!(PgNumeric::try_from(&one_hundred).unwrap(), expected_numeric); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, one_hundred); + assert_eq!(actual_decimal.mantissa(), 10000); + assert_eq!(actual_decimal.scale(), 2); + } +} diff --git a/patches/sqlx-postgres/src/types/str.rs b/patches/sqlx-postgres/src/types/str.rs new file mode 100644 index 000000000..ca7e20a55 --- /dev/null +++ b/patches/sqlx-postgres/src/types/str.rs @@ -0,0 +1,148 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::array_compatible; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; +use std::borrow::Cow; + +impl Type for str { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TEXT + } + + fn compatible(ty: &PgTypeInfo) -> bool { + [ + PgTypeInfo::TEXT, + PgTypeInfo::NAME, + PgTypeInfo::BPCHAR, + PgTypeInfo::VARCHAR, + PgTypeInfo::UNKNOWN, + PgTypeInfo::with_name("citext"), + ] + .contains(ty) + } +} + +impl Type for Cow<'_, str> { + fn type_info() -> PgTypeInfo { + <&str as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Type for Box { + fn type_info() -> PgTypeInfo { + <&str as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Type for String { + fn type_info() -> PgTypeInfo { + <&str as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl PgHasArrayType for &'_ str { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TEXT_ARRAY + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + array_compatible::<&str>(ty) + } +} + +impl PgHasArrayType for Cow<'_, str> { + fn array_type_info() -> PgTypeInfo { + <&str as PgHasArrayType>::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + <&str as PgHasArrayType>::array_compatible(ty) + } +} + +impl PgHasArrayType for Box { + fn array_type_info() -> PgTypeInfo { + <&str as PgHasArrayType>::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + <&str as 
PgHasArrayType>::array_compatible(ty) + } +} + +impl PgHasArrayType for String { + fn array_type_info() -> PgTypeInfo { + <&str as PgHasArrayType>::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + <&str as PgHasArrayType>::array_compatible(ty) + } +} + +impl Encode<'_, Postgres> for &'_ str { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + buf.extend(self.as_bytes()); + + Ok(IsNull::No) + } +} + +impl Encode<'_, Postgres> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + +impl Encode<'_, Postgres> for Box { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&str as Encode>::encode(&**self, buf) + } +} + +impl Encode<'_, Postgres> for String { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&str as Encode>::encode(&**self, buf) + } +} + +impl<'r> Decode<'r, Postgres> for &'r str { + fn decode(value: PgValueRef<'r>) -> Result { + value.as_str() + } +} + +impl<'r> Decode<'r, Postgres> for Cow<'r, str> { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(Cow::Borrowed(value.as_str()?)) + } +} + +impl<'r> Decode<'r, Postgres> for Box { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(Box::from(value.as_str()?)) + } +} + +impl Decode<'_, Postgres> for String { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(value.as_str()?.to_owned()) + } +} diff --git a/patches/sqlx-postgres/src/types/text.rs b/patches/sqlx-postgres/src/types/text.rs new file mode 100644 index 000000000..b5b0a5ed7 --- /dev/null +++ b/patches/sqlx-postgres/src/types/text.rs @@ -0,0 +1,40 @@ +use crate::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::{Text, Type}; +use std::fmt::Display; +use std::str::FromStr; + +use std::io::Write; + +impl Type for Text { + fn type_info() -> PgTypeInfo { + >::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'q, T> Encode<'q, Postgres> for Text +where + T: Display, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + write!(**buf, "{}", self.0)?; + Ok(IsNull::No) + } +} + +impl<'r, T> Decode<'r, Postgres> for Text +where + T: FromStr, + BoxDynError: From<::Err>, +{ + fn decode(value: PgValueRef<'r>) -> Result { + let s: &str = Decode::::decode(value)?; + Ok(Self(s.parse()?)) + } +} diff --git a/patches/sqlx-postgres/src/types/time/date.rs b/patches/sqlx-postgres/src/types/time/date.rs new file mode 100644 index 000000000..2afa57ee0 --- /dev/null +++ b/patches/sqlx-postgres/src/types/time/date.rs @@ -0,0 +1,52 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::time::PG_EPOCH; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use std::mem; +use time::macros::format_description; +use time::{Date, Duration}; + +impl Type for Date { + fn type_info() -> PgTypeInfo { + PgTypeInfo::DATE + } +} + +impl PgHasArrayType for Date { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::DATE_ARRAY + } +} + +impl Encode<'_, Postgres> for Date { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // DATE is encoded as number of days since epoch (2000-01-01) + let days: i32 = (*self - 
PG_EPOCH).whole_days().try_into().map_err(|_| { + format!("value {self:?} would overflow binary encoding for Postgres DATE") + })?; + Encode::::encode(days, buf) + } + + fn size_hint(&self) -> usize { + mem::size_of::() + } +} + +impl<'r> Decode<'r, Postgres> for Date { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => { + // DATE is encoded as the days since epoch + let days: i32 = Decode::::decode(value)?; + PG_EPOCH + Duration::days(days.into()) + } + + PgValueFormat::Text => Date::parse( + value.as_str()?, + &format_description!("[year]-[month]-[day]"), + )?, + }) + } +} diff --git a/patches/sqlx-postgres/src/types/time/datetime.rs b/patches/sqlx-postgres/src/types/time/datetime.rs new file mode 100644 index 000000000..3484116bd --- /dev/null +++ b/patches/sqlx-postgres/src/types/time/datetime.rs @@ -0,0 +1,108 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::time::PG_EPOCH; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use std::borrow::Cow; +use std::mem; +use time::macros::format_description; +use time::macros::offset; +use time::{Duration, OffsetDateTime, PrimitiveDateTime}; + +impl Type for PrimitiveDateTime { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TIMESTAMP + } +} + +impl Type for OffsetDateTime { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TIMESTAMPTZ + } +} + +impl PgHasArrayType for PrimitiveDateTime { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TIMESTAMP_ARRAY + } +} + +impl PgHasArrayType for OffsetDateTime { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TIMESTAMPTZ_ARRAY + } +} + +impl Encode<'_, Postgres> for PrimitiveDateTime { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // TIMESTAMP is encoded as the microseconds since the epoch + let micros: i64 = (*self - PG_EPOCH.midnight()) + .whole_microseconds() + .try_into() + .map_err(|_| { + format!("value {self:?} would overflow binary encoding for Postgres TIME") + })?; + Encode::::encode(micros, buf) + } + + fn size_hint(&self) -> usize { + mem::size_of::() + } +} + +impl<'r> Decode<'r, Postgres> for PrimitiveDateTime { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => { + // TIMESTAMP is encoded as the microseconds since the epoch + let us = Decode::::decode(value)?; + PG_EPOCH.midnight() + Duration::microseconds(us) + } + + PgValueFormat::Text => { + let s = value.as_str()?; + + // If there is no decimal point we need to add one. + let s = if s.contains('.') { + Cow::Borrowed(s) + } else { + Cow::Owned(format!("{s}.0")) + }; + + // Contains a time-zone specifier + // This is given for timestamptz for some reason + // Postgres already guarantees this to always be UTC + if s.contains('+') { + PrimitiveDateTime::parse(&s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))? + } else { + PrimitiveDateTime::parse( + &s, + &format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]" + ), + )? 
+ } + } + }) + } +} + +impl Encode<'_, Postgres> for OffsetDateTime { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let utc = self.to_offset(offset!(UTC)); + let primitive = PrimitiveDateTime::new(utc.date(), utc.time()); + + Encode::::encode(primitive, buf) + } + + fn size_hint(&self) -> usize { + mem::size_of::() + } +} + +impl<'r> Decode<'r, Postgres> for OffsetDateTime { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(>::decode(value)?.assume_utc()) + } +} diff --git a/patches/sqlx-postgres/src/types/time/mod.rs b/patches/sqlx-postgres/src/types/time/mod.rs new file mode 100644 index 000000000..9a45ba833 --- /dev/null +++ b/patches/sqlx-postgres/src/types/time/mod.rs @@ -0,0 +1,9 @@ +mod date; +mod datetime; + +// Parent module is named after the `time` crate, this module is named after the `TIME` SQL type. +#[allow(clippy::module_inception)] +mod time; + +#[rustfmt::skip] +const PG_EPOCH: ::time::Date = ::time::macros::date!(2000-1-1); diff --git a/patches/sqlx-postgres/src/types/time/time.rs b/patches/sqlx-postgres/src/types/time/time.rs new file mode 100644 index 000000000..635170d14 --- /dev/null +++ b/patches/sqlx-postgres/src/types/time/time.rs @@ -0,0 +1,53 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use std::mem; +use time::macros::format_description; +use time::{Duration, Time}; + +impl Type for Time { + fn type_info() -> PgTypeInfo { + PgTypeInfo::TIME + } +} + +impl PgHasArrayType for Time { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::TIME_ARRAY + } +} + +impl Encode<'_, Postgres> for Time { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // TIME is encoded as the microseconds since midnight. + // + // A truncating cast is fine because `self - Time::MIDNIGHT` cannot exceed a span of 24 hours. + #[allow(clippy::cast_possible_truncation)] + let micros: i64 = (*self - Time::MIDNIGHT).whole_microseconds() as i64; + Encode::::encode(micros, buf) + } + + fn size_hint(&self) -> usize { + mem::size_of::() + } +} + +impl<'r> Decode<'r, Postgres> for Time { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(match value.format() { + PgValueFormat::Binary => { + // TIME is encoded as the microseconds since midnight + let us = Decode::::decode(value)?; + Time::MIDNIGHT + Duration::microseconds(us) + } + + PgValueFormat::Text => Time::parse( + value.as_str()?, + // Postgres will not include the subsecond part if it's zero. 
+ &format_description!("[hour]:[minute]:[second][optional [.[subsecond]]]"), + )?, + }) + } +} diff --git a/patches/sqlx-postgres/src/types/time_tz.rs b/patches/sqlx-postgres/src/types/time_tz.rs new file mode 100644 index 000000000..e3de79ea5 --- /dev/null +++ b/patches/sqlx-postgres/src/types/time_tz.rs @@ -0,0 +1,176 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use byteorder::{BigEndian, ReadBytesExt}; +use std::io::Cursor; +use std::mem; + +#[cfg(feature = "time")] +type DefaultTime = ::time::Time; + +#[cfg(all(not(feature = "time"), feature = "chrono"))] +type DefaultTime = ::chrono::NaiveTime; + +#[cfg(feature = "time")] +type DefaultOffset = ::time::UtcOffset; + +#[cfg(all(not(feature = "time"), feature = "chrono"))] +type DefaultOffset = ::chrono::FixedOffset; + +/// Represents a moment of time, in a specified timezone. +/// +/// # Warning +/// +/// `PgTimeTz` provides `TIMETZ` and is supported only for reading from legacy databases. +/// [PostgreSQL recommends] to use `TIMESTAMPTZ` instead. +/// +/// [PostgreSQL recommends]: https://wiki.postgresql.org/wiki/Don't_Do_This#Don.27t_use_timetz +#[derive(Debug, PartialEq, Clone, Copy)] +pub struct PgTimeTz