diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 5a93f6f27b43..5578517ec359 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -38,19 +38,8 @@ runs: rustup toolchain install ${{ inputs.rust-version }} rustup default ${{ inputs.rust-version }} rustup component add rustfmt - - name: Disable debuginfo generation - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - shell: bash - run: echo "RUSTFLAGS=-C debuginfo=1" >> $GITHUB_ENV - - name: Disable incremental compilation - # Disable incremental compilation to save diskspace (the CI doesn't recompile modified files) - # https://github.com/apache/arrow-datafusion/issues/6676 - shell: bash - run: echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV - - name: Enable backtraces - shell: bash - run: echo "RUST_BACKTRACE=1" >> $GITHUB_ENV + - name: Configure rust runtime env + uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions # https://github.com/actions/checkout/issues/766 shell: bash diff --git a/.github/actions/setup-macos-builder/action.yaml b/.github/actions/setup-macos-builder/action.yaml new file mode 100644 index 000000000000..02419f617942 --- /dev/null +++ b/.github/actions/setup-macos-builder/action.yaml @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Prepare Rust Builder for MacOS +description: 'Prepare Rust Build Environment for MacOS' +inputs: + rust-version: + description: 'version of rust to install (e.g. stable)' + required: true + default: 'stable' +runs: + using: "composite" + steps: + - name: Install protobuf compiler + shell: bash + run: | + mkdir -p $HOME/d/protoc + cd $HOME/d/protoc + export PROTO_ZIP="protoc-21.4-osx-x86_64.zip" + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP + unzip $PROTO_ZIP + echo "$HOME/d/protoc/bin" >> $GITHUB_PATH + export PATH=$PATH:$HOME/d/protoc/bin + protoc --version + - name: Setup Rust toolchain + shell: bash + run: | + rustup update stable + rustup toolchain install stable + rustup default stable + rustup component add rustfmt + - name: Configure rust runtime env + uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-rust-runtime/action.yaml b/.github/actions/setup-rust-runtime/action.yaml new file mode 100644 index 000000000000..90e09a957cd4 --- /dev/null +++ b/.github/actions/setup-rust-runtime/action.yaml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Setup Rust Runtime
+description: 'Setup Rust Runtime Environment'
+runs:
+  using: "composite"
+  steps:
+    - name: Run sccache-cache
+      uses: mozilla-actions/sccache-action@v0.0.3
+    - name: Configure runtime env
+      shell: bash
+      # do not produce debug symbols to keep memory usage down
+      # hardcoding other profile params to avoid profile override values
+      # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings
+      #
+      # Set debuginfo=line-tables-only as debuginfo=0 causes an immensely slow build
+      # See for more details: https://github.com/rust-lang/rust/issues/119560
+      #
+      # set RUST_MIN_STACK to avoid rust stack overflows on tpc-ds tests
+      run: |
+        echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV
+        echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV
+        echo "RUST_BACKTRACE=1" >> $GITHUB_ENV
+        echo "RUST_MIN_STACK=3000000" >> $GITHUB_ENV
+        echo "RUSTFLAGS=-C debuginfo=line-tables-only -C incremental=false" >> $GITHUB_ENV
+
diff --git a/.github/actions/setup-windows-builder/action.yaml b/.github/actions/setup-windows-builder/action.yaml
new file mode 100644
index 000000000000..9ab5c4a8b1bb
--- /dev/null
+++ b/.github/actions/setup-windows-builder/action.yaml
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Prepare Rust Builder for Windows
+description: 'Prepare Rust Build Environment for Windows'
+inputs:
+  rust-version:
+    description: 'version of rust to install (e.g.
stable)' + required: true + default: 'stable' +runs: + using: "composite" + steps: + - name: Install protobuf compiler + shell: bash + run: | + mkdir -p $HOME/d/protoc + cd $HOME/d/protoc + export PROTO_ZIP="protoc-21.4-win64.zip" + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP + unzip $PROTO_ZIP + export PATH=$PATH:$HOME/d/protoc/bin + protoc.exe --version + - name: Setup Rust toolchain + shell: bash + run: | + rustup update stable + rustup toolchain install stable + rustup default stable + rustup component add rustfmt + - name: Configure rust runtime env + uses: ./.github/actions/setup-rust-runtime diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 622521a6fbc7..62992e7acf68 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -96,17 +96,9 @@ jobs: - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: - rust-version: stable + rust-version: stable - name: Run tests (excluding doctests) run: cargo test --lib --tests --bins --features avro,json,backtrace - env: - # do not produce debug symbols to keep memory usage down - # hardcoding other profile params to avoid profile override values - # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings - RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" - RUST_BACKTRACE: "1" - # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -284,24 +276,8 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Install protobuf compiler - shell: bash - run: | - mkdir -p $HOME/d/protoc - cd $HOME/d/protoc - export PROTO_ZIP="protoc-21.4-win64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP - unzip $PROTO_ZIP - export PATH=$PATH:$HOME/d/protoc/bin - protoc.exe --version - # TODO: this won't cache anything, which is expensive. Setup this action - # with a OS-dependent path. 
- name: Setup Rust toolchain - run: | - rustup update stable - rustup toolchain install stable - rustup default stable - rustup component add rustfmt + uses: ./.github/actions/setup-windows-builder - name: Run tests (excluding doctests) shell: bash run: | @@ -309,55 +285,22 @@ jobs: cargo test --lib --tests --bins --features avro,json,backtrace cd datafusion-cli cargo test --lib --tests --bins --all-features - env: - # do not produce debug symbols to keep memory usage down - # use higher optimization level to overcome Windows rust slowness for tpc-ds - # and speed builds: https://github.com/apache/arrow-datafusion/issues/8696 - # Cargo profile docs https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings - RUSTFLAGS: "-C debuginfo=0 -C opt-level=1 -C target-feature=+crt-static -C incremental=false -C codegen-units=256" - RUST_BACKTRACE: "1" - # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" + macos: - name: cargo test (mac) + name: cargo test (macos) runs-on: macos-latest steps: - uses: actions/checkout@v4 with: - submodules: true - - name: Install protobuf compiler - shell: bash - run: | - mkdir -p $HOME/d/protoc - cd $HOME/d/protoc - export PROTO_ZIP="protoc-21.4-osx-x86_64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP - unzip $PROTO_ZIP - echo "$HOME/d/protoc/bin" >> $GITHUB_PATH - export PATH=$PATH:$HOME/d/protoc/bin - protoc --version - # TODO: this won't cache anything, which is expensive. Setup this action - # with a OS-dependent path. + submodules: true - name: Setup Rust toolchain - run: | - rustup update stable - rustup toolchain install stable - rustup default stable - rustup component add rustfmt + uses: ./.github/actions/setup-macos-builder - name: Run tests (excluding doctests) shell: bash run: | cargo test --lib --tests --bins --features avro,json,backtrace cd datafusion-cli - cargo test --lib --tests --bins --all-features - env: - # do not produce debug symbols to keep memory usage down - # hardcoding other profile params to avoid profile override values - # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings - RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" - RUST_BACKTRACE: "1" - # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" + cargo test --lib --tests --bins --all-features test-datafusion-pyarrow: name: cargo test pyarrow (amd64) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 2320a8c314cf..659843783016 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -51,8 +51,8 @@ use url::Url; /// run and execute SQL statements and commands, against a context with the given print options pub async fn exec_from_commands( ctx: &mut SessionContext, - print_options: &PrintOptions, commands: Vec, + print_options: &PrintOptions, ) { for sql in commands { match exec_and_print(ctx, print_options, sql).await { @@ -105,8 +105,8 @@ pub async fn exec_from_lines( } pub async fn exec_from_files( - files: Vec, ctx: &mut SessionContext, + files: Vec, print_options: &PrintOptions, ) { let files = files diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 563d172f2c95..dcfd28df1cb0 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -216,7 +216,7 @@ pub async fn main() -> Result<()> { if commands.is_empty() && files.is_empty() { if !rc.is_empty() { - exec::exec_from_files(rc, &mut 
ctx, &print_options).await + exec::exec_from_files(&mut ctx, rc, &print_options).await } // TODO maybe we can have thiserror for cli but for now let's keep it simple return exec::exec_from_repl(&mut ctx, &mut print_options) @@ -225,11 +225,11 @@ pub async fn main() -> Result<()> { } if !files.is_empty() { - exec::exec_from_files(files, &mut ctx, &print_options).await; + exec::exec_from_files(&mut ctx, files, &print_options).await; } if !commands.is_empty() { - exec::exec_from_commands(&mut ctx, &print_options, commands).await; + exec::exec_from_commands(&mut ctx, commands, &print_options).await; } Ok(()) diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index b8594352b585..b382eb34f62c 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -141,7 +141,8 @@ impl PrintOptions { let mut row_count = 0_usize; let mut with_header = true; - while let Some(Ok(batch)) = stream.next().await { + while let Some(maybe_batch) = stream.next().await { + let batch = maybe_batch?; row_count += batch.num_rows(); self.format.print_batches( &mut writer, diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index aae451add9e7..eecb63d3be65 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -62,6 +62,7 @@ cargo run --example csv_sql - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) +- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs new file mode 100644 index 000000000000..8d5314bfbea5 --- /dev/null +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, Float32Array},
+    record_batch::RecordBatch,
+};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{cast::as_float64_array, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
+
+/// This example shows how to use the full AggregateUDFImpl API to implement a user
+/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements
+/// a function `accumulator` that returns the `Accumulator` instance.
+///
+/// To do so, we must implement the `AggregateUDFImpl` trait.
+#[derive(Debug, Clone)]
+struct GeoMeanUdf {
+    signature: Signature,
+}
+
+impl GeoMeanUdf {
+    /// Create a new instance of the GeoMeanUdf struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take one argument of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl AggregateUDFImpl for GeoMeanUdf {
+    /// We implement as_any so that we can downcast the AggregateUDFImpl trait object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "geo_mean"
+    }
+
+    /// Return the "signature" of this function -- namely what types of arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the accumulator factory; DataFusion uses it to create new accumulators.
+    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(GeometricMean::new()))
+    }
+
+    /// This is the description of the state. accumulator's state() must match the types here.
+    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+        Ok(vec![DataType::Float64, DataType::UInt32])
+    }
+}
+
+/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.
+#[derive(Debug)]
+struct GeometricMean {
+    n: u32,
+    prod: f64,
+}
+
+impl GeometricMean {
+    // how the struct is initialized
+    pub fn new() -> Self {
+        GeometricMean { n: 0, prod: 1.0 }
+    }
+}
+
+// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
+// to use them.
+impl Accumulator for GeometricMean {
+    // This function serializes our state to `ScalarValue`, which DataFusion uses
+    // to pass this state between execution stages.
+    // Note that this can be arbitrary data.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.prod),
+            ScalarValue::from(self.n),
+        ])
+    }
+
+    // DataFusion expects this function to return the final value of this aggregator.
+    // in this case, this is the formula of the geometric mean
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let value = self.prod.powf(1.0 / self.n as f64);
+        Ok(ScalarValue::from(value))
+    }
+
+    // DataFusion calls this function to update the accumulator's state for a batch
+    // of input rows. In this case the product is updated with values from the first column
+    // and the count is updated based on the row count
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Float64(Some(value)) = v {
+                self.prod *= value;
+                self.n += 1;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    // Merge the output of `Self::state()` from other instances of this accumulator
+    // into this accumulator's state
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        let arr = &states[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = states
+                .iter()
+                .map(|array| ScalarValue::try_from_array(array, index))
+                .collect::<Result<Vec<_>>>()?;
+            if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) =
+                (&v[0], &v[1])
+            {
+                self.prod *= prod;
+                self.n += n;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
+
+// create local session context with an in-memory table
+fn create_context() -> Result<SessionContext> {
+    use datafusion::arrow::datatypes::{Field, Schema};
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
+
+    // define data in two partitions
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float32Array::from(vec![64.0]))],
+    )?;
+
+    // declare a new context. In Spark API, this corresponds to a new Spark SQL session
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    // create the AggregateUDF
+    let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
+    ctx.register_udaf(geometric_mean.clone());
+
+    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
+    sql_df.show().await?;
+
+    // get a DataFrame from the context
+    // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
+    let df = ctx.table("t").await?;
+
+    // perform the aggregation
+    let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
+
+    // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
+ + // execute the query + let results = df.collect().await?; + + // downcast the array to the expected type + let result = as_float64_array(results[0].column(0))?; + + // verify that the calculation is correct + assert!((result.value(0) - 8.0).abs() < f64::EPSILON); + println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 6ebf88a0b671..3e7dd2e2af08 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -31,7 +31,9 @@ use arrow::datatypes::Float64Type; use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{internal_err, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, +}; use std::sync::Arc; /// This example shows how to use the full ScalarUDFImpl API to implement a user @@ -40,6 +42,7 @@ use std::sync::Arc; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. +#[derive(Debug, Clone)] struct PowUdf { signature: Signature, aliases: Vec, @@ -183,6 +186,10 @@ impl ScalarUDFImpl for PowUdf { fn aliases(&self) -> &[String] { &self.aliases } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } } /// In this example we register `PowUdf` as a user defined function diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 91869d80a41a..f46031434fc9 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -34,6 +34,7 @@ use datafusion_expr::{ /// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. /// /// To do so, we must implement the `WindowUDFImpl` trait. +#[derive(Debug, Clone)] struct SmoothItUdf { signature: Signature, } diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 715e1ff2dce6..19e70dc419e4 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -254,5 +254,5 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result, default = None /// If true, filter expressions are be applied during the parquet decoding operation to - /// reduce the number of rows decoded + /// reduce the number of rows decoded. This optimization is sometimes called "late materialization". pub pushdown_filters: bool, default = false /// If true, filter expressions evaluated during the parquet decoding operation @@ -350,7 +366,9 @@ config_namespace! { /// default parquet writer setting pub max_statistics_size: Option, default = None - /// Sets maximum number of rows in a row group + /// Target maximum number of rows in each row group (defaults to 1M + /// rows). Writing larger row groups requires more memory to write, but + /// can get better compression and be faster to read. pub max_row_group_size: usize, default = 1024 * 1024 /// Sets "created by" property @@ -414,6 +432,10 @@ config_namespace! { config_namespace! 
{ /// Options related to aggregate execution + /// + /// See also: [`SessionConfig`] + /// + /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct AggregateOptions { /// Specifies the threshold for using `ScalarValue`s to update /// accumulators during high-cardinality aggregations for each input batch. @@ -431,6 +453,10 @@ config_namespace! { config_namespace! { /// Options related to query optimization + /// + /// See also: [`SessionConfig`] + /// + /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct OptimizerOptions { /// When set to true, the optimizer will push a limit operation into /// grouped aggregations which have no aggregate expressions, as a soft limit, @@ -539,6 +565,10 @@ config_namespace! { config_namespace! { /// Options controlling explain output + /// + /// See also: [`SessionConfig`] + /// + /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct ExplainOptions { /// When set to true, the explain statement will only print logical plans pub logical_plan_only: bool, default = false diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index d6e4490cec4c..85b97aac037d 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::error::{ unqualified_field_not_found, DataFusionError, Result, SchemaError, _plan_err, + _schema_err, }; use crate::{ field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference, @@ -141,11 +142,9 @@ impl DFSchema { if let Some(qualifier) = field.qualifier() { qualified_names.insert((qualifier, field.name())); } else if !unqualified_names.insert(field.name()) { - return Err(DataFusionError::SchemaError( - SchemaError::DuplicateUnqualifiedField { - name: field.name().to_string(), - }, - )); + return _schema_err!(SchemaError::DuplicateUnqualifiedField { + name: field.name().to_string(), + }); } } @@ -159,14 +158,12 @@ impl DFSchema { qualified_names.sort(); for (qualifier, name) in &qualified_names { if unqualified_names.contains(name) { - return Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: Some((*qualifier).clone()), - name: name.to_string(), - }, - }, - )); + return _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: Some((*qualifier).clone()), + name: name.to_string(), + } + }); } } Ok(Self { @@ -230,9 +227,9 @@ impl DFSchema { for field in other_schema.fields() { // skip duplicate columns let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q), field.name()).is_ok(), + Some(q) => self.has_column_with_qualified_name(q, field.name()), // for unqualified columns, check as unqualified name - None => self.field_with_unqualified_name(field.name()).is_ok(), + None => self.has_column_with_unqualified_name(field.name()), }; if !duplicated_field { self.fields.push(field.clone()); @@ -392,14 +389,12 @@ impl DFSchema { if fields_without_qualifier.len() == 1 { Ok(fields_without_qualifier[0]) } else { - Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, + _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: None, + name: name.to_string(), }, - )) + }) } } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 
e58faaa15096..331f5910d7e5 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -47,54 +47,82 @@ pub type GenericError = Box; #[derive(Debug)] pub enum DataFusionError { /// Error returned by arrow. + /// /// 2nd argument is for optional backtrace ArrowError(ArrowError, Option), - /// Wraps an error from the Parquet crate + /// Error when reading / writing Parquet data. #[cfg(feature = "parquet")] ParquetError(ParquetError), - /// Wraps an error from the Avro crate + /// Error when reading Avro data. #[cfg(feature = "avro")] AvroError(AvroError), - /// Wraps an error from the object_store crate + /// Error when reading / writing to / from an object_store (e.g. S3 or LocalFile) #[cfg(feature = "object_store")] ObjectStore(object_store::Error), - /// Error associated to I/O operations and associated traits. + /// Error when an I/O operation fails IoError(io::Error), - /// Error returned when SQL is syntactically incorrect. + /// Error when SQL is syntactically incorrect. + /// /// 2nd argument is for optional backtrace SQL(ParserError, Option), - /// Error returned on a branch that we know it is possible - /// but to which we still have no implementation for. - /// Often, these errors are tracked in our issue tracker. + /// Error when a feature is not yet implemented. + /// + /// These errors are sometimes returned for features that are still in + /// development and are not entirely complete. Often, these errors are + /// tracked in our issue tracker. NotImplemented(String), - /// Error returned as a consequence of an error in DataFusion. - /// This error should not happen in normal usage of DataFusion. + /// Error due to bugs in DataFusion + /// + /// This error should not happen in normal usage of DataFusion. It results + /// from something that wasn't expected/anticipated by the implementation + /// and that is most likely a bug (the error message even encourages users + /// to open a bug report). A user should not be able to trigger internal + /// errors under normal circumstances by feeding in malformed queries, bad + /// data, etc. + /// + /// Note that I/O errors (or any error that happens due to external systems) + /// do NOT fall under this category. See other variants such as + /// [`Self::IoError`] and [`Self::External`]. /// - /// DataFusions has internal invariants that the compiler is not - /// always able to check. This error is raised when one of those - /// invariants is not verified during execution. + /// DataFusions has internal invariants that the compiler is not always able + /// to check. This error is raised when one of those invariants does not + /// hold for some reason. Internal(String), - /// This error happens whenever a plan is not valid. Examples include - /// impossible casts. + /// Error during planning of the query. + /// + /// This error happens when the user provides a bad query or plan, for + /// example the user attempts to call a function that doesn't exist, or if + /// the types of a function call are not supported. Plan(String), - /// This error happens when an invalid or unsupported option is passed - /// in a SQL statement + /// Error for invalid or unsupported configuration options. Configuration(String), - /// This error happens with schema-related errors, such as schema inference not possible - /// and non-unique column names. - SchemaError(SchemaError), - /// Error returned during execution of the query. - /// Examples include files not found, errors in parsing certain types. 
+ /// Error when there is a problem with the query related to schema. + /// + /// This error can be returned in cases such as when schema inference is not + /// possible and when column names are not unique. + /// + /// 2nd argument is for optional backtrace + /// Boxing the optional backtrace to prevent + SchemaError(SchemaError, Box>), + /// Error during execution of the query. + /// + /// This error is returned when an error happens during execution due to a + /// malformed input. For example, the user passed malformed arguments to a + /// SQL method, opened a CSV file that is broken, or tried to divide an + /// integer by zero. Execution(String), - /// This error is thrown when a consumer cannot acquire memory from the Memory Manager - /// we can just cancel the execution of the partition. + /// Error when resources (such as memory of scratch disk space) are exhausted. + /// + /// This error is thrown when a consumer cannot acquire additional memory + /// or other resources needed to execute the query from the Memory Manager. ResourcesExhausted(String), /// Errors originating from outside DataFusion's core codebase. + /// /// For example, a custom S3Error from the crate datafusion-objectstore-s3 External(GenericError), /// Error with additional context Context(String, Box), - /// Errors originating from either mapping LogicalPlans to/from Substrait plans + /// Errors from either mapping LogicalPlans to/from Substrait plans /// or serializing/deserializing protobytes to Substrait plans Substrait(String), } @@ -125,34 +153,6 @@ pub enum SchemaError { }, } -/// Create a "field not found" DataFusion::SchemaError -pub fn field_not_found>( - qualifier: Option, - name: &str, - schema: &DFSchema, -) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new(qualifier, name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - -/// Convenience wrapper over [`field_not_found`] for when there is no qualifier -pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new_unqualified(name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - impl Display for SchemaError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -298,7 +298,7 @@ impl Display for DataFusionError { write!(f, "IO error: {desc}") } DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); + let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { @@ -314,8 +314,10 @@ impl Display for DataFusionError { DataFusionError::Plan(ref desc) => { write!(f, "Error during planning: {desc}") } - DataFusionError::SchemaError(ref desc) => { - write!(f, "Schema error: {desc}") + DataFusionError::SchemaError(ref desc, ref backtrace) => { + let backtrace: &str = + &backtrace.as_ref().clone().unwrap_or("".to_owned()); + write!(f, "Schema error: {desc}{backtrace}") } DataFusionError::Execution(ref desc) => { write!(f, "Execution error: {desc}") @@ -356,7 +358,7 @@ impl Error for DataFusionError { DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, - DataFusionError::SchemaError(e) => Some(e), + DataFusionError::SchemaError(e, _) => Some(e), 
             DataFusionError::Execution(_) => None,
             DataFusionError::ResourcesExhausted(_) => None,
             DataFusionError::External(e) => Some(e.as_ref()),
@@ -556,12 +558,63 @@ macro_rules! arrow_err {
     };
 }
+// Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace
+#[macro_export]
+macro_rules! schema_datafusion_err {
+    ($ERR:expr) => {
+        DataFusionError::SchemaError(
+            $ERR,
+            Box::new(Some(DataFusionError::get_back_trace())),
+        )
+    };
+}
+
+// Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace
+#[macro_export]
+macro_rules! schema_err {
+    ($ERR:expr) => {
+        Err(DataFusionError::SchemaError(
+            $ERR,
+            Box::new(Some(DataFusionError::get_back_trace())),
+        ))
+    };
+}
+
 // To avoid compiler error when using macro in the same crate:
 // macros from the current crate cannot be referred to by absolute paths
 pub use internal_datafusion_err as _internal_datafusion_err;
 pub use internal_err as _internal_err;
 pub use not_impl_err as _not_impl_err;
 pub use plan_err as _plan_err;
+pub use schema_err as _schema_err;
+
+/// Create a "field not found" DataFusion::SchemaError
+pub fn field_not_found<R: Into<OwnedTableReference>>(
+    qualifier: Option<R>,
+    name: &str,
+    schema: &DFSchema,
+) -> DataFusionError {
+    schema_datafusion_err!(SchemaError::FieldNotFound {
+        field: Box::new(Column::new(qualifier, name)),
+        valid_fields: schema
+            .fields()
+            .iter()
+            .map(|f| f.qualified_column())
+            .collect(),
+    })
+}
+
+/// Convenience wrapper over [`field_not_found`] for when there is no qualifier
+pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError {
+    schema_datafusion_err!(SchemaError::FieldNotFound {
+        field: Box::new(Column::new_unqualified(name)),
+        valid_fields: schema
+            .fields()
+            .iter()
+            .map(|f| f.qualified_column())
+            .collect(),
+    })
+}
 #[cfg(test)]
 mod test {
diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs
index 5c36f41a6e42..8dcc00ca1c29 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -214,22 +214,19 @@ fn hash_struct_array(
     hashes_buffer: &mut [u64],
 ) -> Result<()> {
     let nulls = array.nulls();
-    let num_columns = array.num_columns();
+    let row_len = array.len();
-    // Skip null columns
-    let valid_indices: Vec<usize> = if let Some(nulls) = nulls {
+    let valid_row_indices: Vec<usize> = if let Some(nulls) = nulls {
         nulls.valid_indices().collect()
     } else {
-        (0..num_columns).collect()
+        (0..row_len).collect()
     };
     // Create hashes for each row that combines the hashes over all the column at that row.
-    // array.len() is the number of rows.
-    let mut values_hashes = vec![0u64; array.len()];
+    let mut values_hashes = vec![0u64; row_len];
     create_hashes(array.columns(), random_state, &mut values_hashes)?;
-    // Skip the null columns, nulls should get hash value 0.
- for i in valid_indices { + for i in valid_row_indices { let hash = &mut hashes_buffer[i]; *hash = combine_hashes(*hash, values_hashes[i]); } @@ -601,6 +598,39 @@ mod tests { assert_eq!(hashes[4], hashes[5]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays_more_column_than_row() { + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-1", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-2", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-3", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ]); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + + let array = Arc::new(struct_array) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 48878aa9bd99..8820ca9942fc 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -21,6 +21,7 @@ use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashSet; use std::convert::{Infallible, TryInto}; +use std::hash::Hash; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -142,13 +143,13 @@ pub enum ScalarValue { /// Fixed size list scalar. /// /// The array must be a FixedSizeListArray with length 1. - FixedSizeList(ArrayRef), + FixedSizeList(Arc), /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] /// /// The array must be a ListArray with length 1. - List(ArrayRef), + List(Arc), /// The array must be a LargeListArray with length 1. 
- LargeList(ArrayRef), + LargeList(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -360,45 +361,13 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(arr1), List(arr2)) - | (FixedSizeList(arr1), FixedSizeList(arr2)) - | (LargeList(arr1), LargeList(arr2)) => { - // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); - - if arr1.data_type() != arr2.data_type() { - return None; - } - - fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { - if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_fixed_size_list_opt() { - arr.value(0) - } else { - unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") - } - } - - let arr1 = first_array_for_list(arr1); - let arr2 = first_array_for_list(arr2); - - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } - - Some(Ordering::Equal) + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()), + (FixedSizeList(arr1), FixedSizeList(arr2)) => { + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) + } + (LargeList(arr1), LargeList(arr2)) => { + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), @@ -464,6 +433,44 @@ impl PartialOrd for ScalarValue { } } +/// List/LargeList/FixedSizeList scalars always have a single element +/// array. This function returns that array +fn first_array_for_list(arr: &dyn Array) -> ArrayRef { + assert_eq!(arr.len(), 1); + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } +} + +/// Compares two List/LargeList/FixedSizeList scalars +fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { + if arr1.data_type() != arr2.data_type() { + return None; + } + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. 
Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -517,14 +524,14 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(arr) | LargeList(arr) | FixedSizeList(arr) => { - let arrays = vec![arr.to_owned()]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = - create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); - // Hash back to std::hash::Hasher - hashes.hash(state); + List(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); + } + LargeList(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); + } + FixedSizeList(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -557,6 +564,15 @@ impl std::hash::Hash for ScalarValue { } } +fn hash_list(arr: ArrayRef, state: &mut H) { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); +} + /// Return a reference to the values array and the index into it for a /// dictionary array /// @@ -942,9 +958,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), + ScalarValue::List(arr) => arr.data_type().to_owned(), + ScalarValue::LargeList(arr) => arr.data_type().to_owned(), + ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1147,9 +1163,9 @@ impl ScalarValue { ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. 
- ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), + ScalarValue::List(arr) => arr.len() == arr.null_count(), + ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), + ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1695,17 +1711,16 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); - /// let result = as_list_array(&array).unwrap(); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1730,17 +1745,19 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); - /// let result = as_large_list_array(&array).unwrap(); + /// let result = ScalarValue::new_large_list(&scalars, &DataType::Int32); /// /// let expected = LargeListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_large_list( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1876,14 +1893,14 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - let arrays = std::iter::repeat(arr.as_ref()) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(|e| arrow_datafusion_err!(e))? + ScalarValue::List(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } + ScalarValue::LargeList(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } + ScalarValue::FixedSizeList(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -2040,6 +2057,11 @@ impl ScalarValue { } } + fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { + let arrays = std::iter::repeat(arr).take(size).collect::>(); + arrow::compute::concat(arrays.as_slice()).map_err(|e| arrow_datafusion_err!(e)) + } + /// Retrieve ScalarValue for each row in `array` /// /// Example @@ -2433,11 +2455,14 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? 
} - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - let right = array.slice(index, 1); - arr == &right + ScalarValue::List(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } + ScalarValue::LargeList(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } + ScalarValue::FixedSizeList(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? @@ -2515,6 +2540,11 @@ impl ScalarValue { }) } + fn eq_array_list(arr1: &ArrayRef, arr2: &ArrayRef, index: usize) -> bool { + let right = arr2.slice(index, 1); + arr1 == &right + } + /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { @@ -2561,9 +2591,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), + ScalarValue::List(arr) => arr.get_array_memory_size(), + ScalarValue::LargeList(arr) => arr.get_array_memory_size(), + ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2865,14 +2895,19 @@ impl TryFrom<&DataType> for ScalarValue { Box::new(value_type.as_ref().try_into()?), ), // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field) => ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - 1, - )), + DataType::List(field) => ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { @@ -2937,16 +2972,9 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - // ScalarValue List should always have a single element - assert_eq!(arr.len(), 1); - let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); - let value_formatter = formatter.value(0); - write!(f, "{value_formatter}")? 
- } + ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -2979,6 +3007,16 @@ impl fmt::Display for ScalarValue { } } +fn fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { + // ScalarValue List, LargeList, FixedSizeList should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}") +} + impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -3182,15 +3220,14 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); let expected = array_into_list_array(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", ]))); - let result = as_list_array(&array); - assert_eq!(result, &expected); + assert_eq!(*result, expected); } fn build_list( @@ -3226,9 +3263,9 @@ mod tests { }; if O::IS_LARGE { - ScalarValue::LargeList(arr) + ScalarValue::LargeList(arr.as_list::().to_owned().into()) } else { - ScalarValue::List(arr) + ScalarValue::List(arr.as_list::().to_owned().into()) } }) .collect() @@ -3311,18 +3348,16 @@ mod tests { ])); let fsl_array: ArrayRef = - Arc::new(FixedSizeListArray::from_iter_primitive::( - vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - ], - 3, - )); + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ])); for arr in [list_array, fsl_array] { for i in 0..arr.len() { - let scalar = ScalarValue::List(arr.slice(i, 1)); + let scalar = + ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); assert!(scalar.eq_array(&arr, i).unwrap()); } } @@ -3676,8 +3711,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3685,8 +3719,7 @@ mod tests { #[test] fn scalar_large_list_null_to_array() { - let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); - let list_array = as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3699,8 +3732,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3720,8 +3752,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); - let list_array = 
as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3959,10 +3990,15 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - 1, - )); + let expected = ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } @@ -3977,14 +4013,19 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ))), - 1, - )); + let expected = ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 9de6a7f7d6a0..c2e8c2b44531 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -37,7 +37,7 @@ path = "src/lib.rs" # Used to enable the avro format avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] -compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression"] +compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"] crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"] default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"] encoding_expressions = ["datafusion-physical-expr/encoding_expressions"] @@ -87,8 +87,8 @@ pin-project-lite = "^0.2.7" rand = { workspace = true } sqlparser = { workspace = true } tempfile = { workspace = true } -tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } -tokio-util = { version = "0.7.4", features = ["io"] } +tokio = { version = "1.28", features = ["macros", "rt", "sync"] } +tokio-util = { version = "0.7.4", features = ["io"], optional = true } url = { workspace = true } uuid = { version = "1.0", features = ["v4"] } xz2 = { version = "0.1", optional = true } @@ -113,6 +113,7 @@ rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } serde_json = { workspace = true } test-utils = { path = "../../test-utils" } thiserror = { workspace = true } +tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.27.1", features = ["fs"] } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5a8c706e32cd..f15f1e9ba6fb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -802,6 +802,7 @@ impl DataFrame { /// Executes this DataFrame and returns a stream over a single partition /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -813,6 +814,11 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + /// 
+ /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream(self) -> Result { let task_ctx = Arc::new(self.task_ctx()); let plan = self.create_physical_plan().await?; @@ -841,6 +847,7 @@ impl DataFrame { /// Executes this DataFrame and returns one stream per partition. /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -852,6 +859,10 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream_partitioned( self, ) -> Result> { @@ -1175,7 +1186,7 @@ impl DataFrame { let field_to_rename = match self.plan.schema().field_from_column(&old_column) { Ok(field) => field, // no-op if field not found - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { + Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _)) => { return Ok(self) } Err(err) => return Err(err), diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 7a0af3ff0809..9cae6675e825 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -423,9 +423,8 @@ impl CsvSerializer { } } -#[async_trait] impl BatchSerializer for CsvSerializer { - async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { + fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); let header = self.header && initial; @@ -829,7 +828,7 @@ mod tests { .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; let serializer = CsvSerializer::new(); - let bytes = serializer.serialize(batch, true).await?; + let bytes = serializer.serialize(batch, true)?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -853,7 +852,7 @@ mod tests { .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; let serializer = CsvSerializer::new().with_header(false); - let bytes = serializer.serialize(batch, true).await?; + let bytes = serializer.serialize(batch, true)?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 8c02955ad363..0f6d3648d120 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -204,9 +204,8 @@ impl JsonSerializer { } } -#[async_trait] impl BatchSerializer for JsonSerializer { - async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { + fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { let mut buffer = Vec::with_capacity(4096); let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index c481f2accf19..410a32a19cc1 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -29,7 +29,6 @@ use crate::error::Result; use arrow_array::RecordBatch; use 
datafusion_common::DataFusionError; -use async_trait::async_trait; use bytes::Bytes; use futures::future::BoxFuture; use object_store::path::Path; @@ -144,12 +143,11 @@ impl AsyncWrite for AbortableWrite { } /// A trait that defines the methods required for a RecordBatch serializer. -#[async_trait] pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. /// Parameter `initial` signals whether the given batch is the first batch. /// This distinction is important for certain serializers (like CSV). - async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; + fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } /// Returns an [`AbortableWrite`] which writes to the given object store location diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 9b820a15b280..106b4e0d50e5 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -60,7 +60,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( let serializer_clone = serializer.clone(); let handle = tokio::spawn(async move { let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch, initial).await?; + let bytes = serializer_clone.serialize(batch, initial)?; Ok((num_rows, bytes)) }); if initial { diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 68de55e1a410..a03bcec7abec 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -285,7 +285,7 @@ async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { - let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; + let expr = create_physical_expr(filter, &df_schema, &props).ok()?; expr.evaluate(&batch) .ok()? .into_array(partitions.len()) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a7af1bf1be28..de207b6d9019 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -662,12 +662,8 @@ impl TableProvider for ListingTable { let filters = if let Some(expr) = conjunction(filters.to_vec()) { // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. 
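Editor's note: since `BatchSerializer::serialize` is now a plain synchronous method (the `#[async_trait]` attribute and `async` keyword are removed above), a custom serializer no longer needs any async machinery. A minimal sketch mirroring the `JsonSerializer` change in this diff; the import path of the trait is an assumption based on the module touched here.

```rust
use bytes::Bytes;
use datafusion::arrow::json::LineDelimitedWriter;
use datafusion::arrow::record_batch::RecordBatch;
// Path assumed from datafusion/core/src/datasource/file_format/write/mod.rs.
use datafusion::datasource::file_format::write::BatchSerializer;
use datafusion::error::Result;

/// Serializes each batch as newline-delimited JSON.
struct NdjsonSerializer;

impl BatchSerializer for NdjsonSerializer {
    // Synchronous: serialization is CPU-bound, so no `async fn` is needed.
    fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let mut writer = LineDelimitedWriter::new(&mut buffer);
        writer.write(&batch)?;
        Ok(buffer.into())
    }
}
```

Callers such as `serialize_rb_stream_to_object_store` (see the orchestration hunk below) still invoke `serialize` inside `tokio::spawn`, so CPU-heavy serialization stays off the driving task.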
let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; - let filters = create_physical_expr( - &expr, - &table_df_schema, - &self.table_schema, - state.execution_props(), - )?; + let filters = + create_physical_expr(&expr, &table_df_schema, state.execution_props())?; Some(filters) } else { None diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 766dee7de901..6421edf77972 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -103,12 +103,14 @@ impl ListingTableUrl { let s = s.as_ref(); // This is necessary to handle the case of a path starting with a drive letter + #[cfg(not(target_arch = "wasm32"))] if std::path::Path::new(s).is_absolute() { return Self::parse_path(s); } match Url::parse(s) { Ok(url) => Self::try_new(url, None), + #[cfg(not(target_arch = "wasm32"))] Err(url::ParseError::RelativeUrlWithoutBase) => Self::parse_path(s), Err(e) => Err(DataFusionError::External(Box::new(e))), } @@ -146,6 +148,7 @@ impl ListingTableUrl { } /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path + #[cfg(not(target_arch = "wasm32"))] fn parse_path(s: &str) -> Result { let (path, glob) = match split_glob_expression(s) { Some((prefix, glob)) => { @@ -282,6 +285,7 @@ impl ListingTableUrl { } /// Creates a file URL from a potentially relative filesystem path +#[cfg(not(target_arch = "wasm32"))] fn url_from_filesystem_path(s: &str) -> Option { let path = std::path::Path::new(s); let is_dir = match path.exists() { diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index bb4c8313642c..353662397648 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -535,7 +535,6 @@ mod tests { use arrow_schema::Schema; use datafusion_common::{internal_err, DataFusionError, Statistics}; - use async_trait::async_trait; use bytes::Bytes; use futures::StreamExt; @@ -989,9 +988,8 @@ mod tests { bytes: Bytes, } - #[async_trait] impl BatchSerializer for TestSerializer { - async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { + fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs index 915fb56680f5..a17a3c6d9752 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs @@ -29,8 +29,10 @@ use crate::physical_plan::metrics::{ pub struct ParquetFileMetrics { /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, - /// Number of row groups pruned using - pub row_groups_pruned: Count, + /// Number of row groups pruned by bloom filters + pub row_groups_pruned_bloom_filter: Count, + /// Number of row groups pruned by statistics + pub row_groups_pruned_statistics: Count, /// Total number of bytes scanned pub bytes_scanned: Count, /// Total rows filtered out by predicates pushed into parquet scan @@ -54,9 +56,13 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .counter("predicate_evaluation_errors", partition); - let row_groups_pruned = MetricBuilder::new(metrics) + let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) 
.with_new_label("filename", filename.to_string()) - .counter("row_groups_pruned", partition); + .counter("row_groups_pruned_bloom_filter", partition); + + let row_groups_pruned_statistics = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_pruned_statistics", partition); let bytes_scanned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) @@ -79,7 +85,8 @@ impl ParquetFileMetrics { Self { predicate_evaluation_errors, - row_groups_pruned, + row_groups_pruned_bloom_filter, + row_groups_pruned_statistics, bytes_scanned, pushdown_rows_filtered, pushdown_eval_time, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 9d81d8d083c2..c2689cfb10a6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -51,6 +51,7 @@ use datafusion_physical_expr::{ use bytes::Bytes; use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; +use itertools::Itertools; use log::debug; use object_store::path::Path; use object_store::ObjectStore; @@ -278,7 +279,17 @@ impl DisplayAs for ParquetExec { let pruning_predicate_string = self .pruning_predicate .as_ref() - .map(|pre| format!(", pruning_predicate={}", pre.predicate_expr())) + .map(|pre| { + format!( + ", pruning_predicate={}, required_guarantees=[{}]", + pre.predicate_expr(), + pre.literal_guarantees() + .iter() + .map(|item| format!("{}", item)) + .collect_vec() + .join(", ") + ) + }) .unwrap_or_default(); write!(f, "ParquetExec: ")?; @@ -2123,6 +2134,6 @@ mod tests { fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + create_physical_expr(expr, &df_schema, &execution_props).unwrap() } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 151ab5f657b1..3c40509a86d2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -487,6 +487,6 @@ mod test { fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + create_physical_expr(expr, &df_schema, &execution_props).unwrap() } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 24c65423dd4c..c519d41aad01 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -81,7 +81,7 @@ pub(crate) fn prune_row_groups_by_statistics( Ok(values) => { // NB: false means don't scan row group if !values[0] { - metrics.row_groups_pruned.add(1); + metrics.row_groups_pruned_statistics.add(1); continue; } } @@ -159,7 +159,7 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< }; if prune_group { - metrics.row_groups_pruned.add(1); + metrics.row_groups_pruned_bloom_filter.add(1); } else { filtered.push(*idx); } @@ -1010,7 +1010,7 @@ mod tests { fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { let 
df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + create_physical_expr(expr, &df_schema, &execution_props).unwrap() } #[tokio::test] @@ -1049,12 +1049,9 @@ mod tests { let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).in_list( - vec![ - lit("Hello_Not_Exists"), - lit("Hello_Not_Exists2"), - lit("Hello_Not_Exists3"), - lit("Hello_Not_Exist4"), - ], + (1..25) + .map(|i| lit(format!("Hello_Not_Exists{}", i))) + .collect::>(), false, ); let expr = logical2physical(&expr, &schema); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d6b7f046f3e3..1e378541b624 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2124,7 +2124,6 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; - use arrow_schema::Schema; use async_trait::async_trait; use datafusion_expr::Expr; use std::env; @@ -2504,7 +2503,6 @@ mod tests { &self, _expr: &Expr, _input_dfschema: &crate::common::DFSchema, - _input_schema: &Schema, _session_state: &SessionState, ) -> Result> { unimplemented!() diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b3ebbc6e3637..8fc724a22443 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -285,37 +285,38 @@ //! //! ### Logical Plans //! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] -//! expressions which are [`Schema`] aware and represent statements +//! representing expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! Examples of working with and executing `Expr`s can be found in the +//! [`Expr`]s can be rewritten using the [`TreeNode`] API and simplified using +//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be found in the //! [`expr_api`.rs] example //! +//! [`TreeNode`]: datafusion_common::tree_node::TreeNode +//! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier //! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs //! //! ### Physical Plans //! //! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It a DAG of other -//! [`ExecutionPlan`]s each potentially containing expressions of the -//! following types: +//! [`ExecutionPlan`]s each potentially containing expressions that implement the +//! [`PhysicalExpr`] trait. //! -//! 1. [`PhysicalExpr`]: Scalar functions -//! -//! 2. [`AggregateExpr`]: Aggregate functions -//! -//! 2. [`WindowExpr`]: Window functions -//! -//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has concrete +//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has additional concrete //! information about how to perform calculations (e.g. hash vs merge //! join), and how data flows during execution (e.g. partitioning and //! sortedness). //! +//! [cp_solver] performs range propagation analysis on [`PhysicalExpr`]s and +//! [`PruningPredicate`] can prove certain boolean [`PhysicalExpr`]s used for +//! 
filtering can never be `true` using additional statistical information. +//! +//! [cp_solver]: crate::physical_expr::intervals::cp_solver +//! [`PruningPredicate`]: crate::physical_optimizer::pruning::PruningPredicate //! [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -//! [`AggregateExpr`]: crate::physical_plan::AggregateExpr -//! [`WindowExpr`]: crate::physical_plan::WindowExpr //! //! ## Execution //! diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 7359a6463059..61eb2381c63b 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -269,12 +269,13 @@ mod tests { aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); + let n_aggr = aggr_expr.len(); Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, aggr_expr, - vec![], + vec![None; n_aggr], input, schema, ) @@ -288,12 +289,13 @@ mod tests { aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); + let n_aggr = aggr_expr.len(); Arc::new( AggregateExec::try_new( AggregateMode::Final, group_by, aggr_expr, - vec![], + vec![None; n_aggr], input, schema, ) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index bf5aa7d02272..a2f530c0e689 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -925,11 +925,9 @@ fn add_hash_on_top( mut input: DistributionContext, hash_exprs: Vec>, n_target: usize, - repartition_beneficial_stats: bool, ) -> Result { - let partition_count = input.plan.output_partitioning().partition_count(); // Early return if hash repartition is unnecessary - if n_target == partition_count && n_target == 1 { + if n_target == 1 { return Ok(input); } @@ -951,12 +949,6 @@ fn add_hash_on_top( // requirements. // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.prefer_existing_sort`). - if repartition_beneficial_stats { - // Since hashing benefits from partitioning, add a round-robin repartition - // before it: - input = add_roundrobin_on_top(input, n_target)?; - } - let partitioning = Partitioning::Hash(hash_exprs, n_target); let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? .with_preserve_order(); @@ -1198,37 +1190,31 @@ fn ensure_distribution( ) .map( |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { - // Don't need to apply when the returned row count is not greater than 1: + // Don't need to apply when the returned row count is not greater than batch size let num_rows = child.plan.statistics()?.num_rows; let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { num_rows .get_value() .map(|value| value > &batch_size) - .unwrap_or(true) + .unwrap() // safe to unwrap since is_exact() is true } else { true }; - if enable_round_robin + let add_roundrobin = enable_round_robin // Operator benefits from partitioning (e.g. filter): && (would_benefit && repartition_beneficial_stats) - // Unless partitioning doesn't increase the partition count, it is not beneficial: - && child.plan.output_partitioning().partition_count() < target_partitions - { - // When `repartition_file_scans` is set, attempt to increase - // parallelism at the source. 
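Editor's note: the round-robin and `repartition_file_scans` logic in this rule is driven entirely by optimizer and execution settings, which the tests later in this file set directly on `ConfigOptions`. A small sketch of those same knobs (the values are arbitrary):

```rust
use datafusion_common::config::ConfigOptions;

fn main() {
    let mut config = ConfigOptions::new();
    // Partition count that round-robin / hash repartitioning targets.
    config.execution.target_partitions = 8;
    // Allow inserting RoundRobinBatch repartitioning to raise parallelism.
    config.optimizer.enable_round_robin_repartition = true;
    // Split file scans into byte ranges at the source when beneficial
    // (the expected plans below show groups like `[x:0..100]`).
    config.optimizer.repartition_file_scans = true;
    config.optimizer.repartition_file_min_size = 1024;
    println!("target partitions: {}", config.execution.target_partitions);
}
```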
- if repartition_file_scans { - if let Some(new_child) = - child.plan.repartitioned(target_partitions, config)? - { - child.plan = new_child; - } + // Unless partitioning increases the partition count, it is not beneficial: + && child.plan.output_partitioning().partition_count() < target_partitions; + + // When `repartition_file_scans` is set, attempt to increase + // parallelism at the source. + if repartition_file_scans && repartition_beneficial_stats { + if let Some(new_child) = + child.plan.repartitioned(target_partitions, config)? + { + child.plan = new_child; } - // Increase parallelism by adding round-robin repartitioning - // on top of the operator. Note that we only do this if the - // partition count is not already equal to the desired partition - // count. - child = add_roundrobin_on_top(child, target_partitions)?; } // Satisfy the distribution requirement if it is unmet. @@ -1237,14 +1223,20 @@ fn ensure_distribution( child = add_spm_on_top(child); } Distribution::HashPartitioned(exprs) => { - child = add_hash_on_top( - child, - exprs.to_vec(), - target_partitions, - repartition_beneficial_stats, - )?; + if add_roundrobin { + // Add round-robin repartitioning on top of the operator + // to increase parallelism. + child = add_roundrobin_on_top(child, target_partitions)?; + } + child = add_hash_on_top(child, exprs.to_vec(), target_partitions)?; + } + Distribution::UnspecifiedDistribution => { + if add_roundrobin { + // Add round-robin repartitioning on top of the operator + // to increase parallelism. + child = add_roundrobin_on_top(child, target_partitions)?; + } } - Distribution::UnspecifiedDistribution => {} }; // There is an ordering requirement of the operator: @@ -1362,17 +1354,10 @@ impl DistributionContext { fn update_children(mut self) -> Result { for child_context in self.children_nodes.iter_mut() { - child_context.distribution_connection = match child_context.plan.as_any() { - plan_any if plan_any.is::() => matches!( - plan_any - .downcast_ref::() - .unwrap() - .partitioning(), - Partitioning::RoundRobinBatch(_) | Partitioning::Hash(_, _) - ), - plan_any - if plan_any.is::() - || plan_any.is::() => + child_context.distribution_connection = match &child_context.plan { + plan if is_repartition(plan) + || is_coalesce_partitions(plan) + || is_sort_preserving_merge(plan) => { true } @@ -1915,7 +1900,7 @@ pub(crate) mod tests { let distribution_context = DistributionContext::new(plan); let mut config = ConfigOptions::new(); config.execution.target_partitions = target_partitions; - config.optimizer.enable_round_robin_repartition = false; + config.optimizer.enable_round_robin_repartition = true; config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; config.optimizer.prefer_existing_sort = prefer_existing_sort; @@ -3871,14 +3856,14 @@ pub(crate) mod tests { "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e]", ]; let expected_csv = [ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "CsvExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: 
file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], has_header=false", ]; assert_optimized!(expected_parquet, plan_parquet, true, false, 2, true, 10); diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 6b2fe24acf00..ba66dca55b35 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -38,11 +38,12 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::ExecutionPlan; use arrow_schema::Schema; -use datafusion_common::internal_err; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, JoinSide}; use datafusion_common::{DataFusionError, JoinType}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::sort_properties::SortProperties; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to @@ -425,24 +426,97 @@ pub type PipelineFixerSubrule = dyn Fn( &ConfigOptions, ) -> Option>; -/// This subrule checks if we can replace a hash join with a symmetric hash -/// join when we are dealing with infinite inputs on both sides. This change -/// avoids pipeline breaking and preserves query runnability. If possible, -/// this subrule makes this replacement; otherwise, it has no effect. +/// Converts a hash join to a symmetric hash join in the case of infinite inputs on both sides. +/// +/// This subrule checks if a hash join can be replaced with a symmetric hash join when dealing +/// with unbounded (infinite) inputs on both sides. This replacement avoids pipeline breaking and +/// preserves query runnability. If the replacement is applicable, this subrule makes this change; +/// otherwise, it leaves the input unchanged. +/// +/// # Arguments +/// * `input` - The current state of the pipeline, including the execution plan. +/// * `config_options` - Configuration options that might affect the transformation logic. +/// +/// # Returns +/// An `Option` that contains the `Result` of the transformation. If the transformation is not applicable, +/// it returns `None`. If applicable, it returns `Some(Ok(...))` with the modified pipeline state, +/// or `Some(Err(...))` if an error occurs during the transformation. fn hash_join_convert_symmetric_subrule( mut input: PipelineStatePropagator, config_options: &ConfigOptions, ) -> Option> { + // Check if the current plan node is a HashJoinExec. if let Some(hash_join) = input.plan.as_any().downcast_ref::() { + // Determine if left and right children are unbounded. let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + // Update the unbounded flag of the input. input.unbounded = left_unbounded || right_unbounded; + // Process only if both left and right sides are unbounded. let result = if left_unbounded && right_unbounded { + // Determine the partition mode based on configuration. let mode = if config_options.optimizer.repartition_joins { StreamJoinPartitionMode::Partitioned } else { StreamJoinPartitionMode::SinglePartition }; + // A closure to determine the required sort order for each side of the join in the SymmetricHashJoinExec. 
+ // This function checks if the columns involved in the filter have any specific ordering requirements. + // If the child nodes (left or right side of the join) already have a defined order and the columns used in the + // filter predicate are ordered, this function captures that ordering requirement. The identified order is then + // used in the SymmetricHashJoinExec to maintain bounded memory during join operations. + // However, if the child nodes do not have an inherent order, or if the filter columns are unordered, + // the function concludes that no specific order is required for the SymmetricHashJoinExec. This approach + // ensures that the symmetric hash join operation only imposes ordering constraints when necessary, + // based on the properties of the child nodes and the filter condition. + let determine_order = |side: JoinSide| -> Option> { + hash_join + .filter() + .map(|filter| { + filter.column_indices().iter().any( + |ColumnIndex { + index, + side: column_side, + }| { + // Skip if column side does not match the join side. + if *column_side != side { + return false; + } + // Retrieve equivalence properties and schema based on the side. + let (equivalence, schema) = match side { + JoinSide::Left => ( + hash_join.left().equivalence_properties(), + hash_join.left().schema(), + ), + JoinSide::Right => ( + hash_join.right().equivalence_properties(), + hash_join.right().schema(), + ), + }; + + let name = schema.field(*index).name(); + let col = Arc::new(Column::new(name, *index)) as _; + // Check if the column is ordered. + equivalence.get_expr_ordering(col).state + != SortProperties::Unordered + }, + ) + }) + .unwrap_or(false) + .then(|| { + match side { + JoinSide::Left => hash_join.left().output_ordering(), + JoinSide::Right => hash_join.right().output_ordering(), + } + .map(|p| p.to_vec()) + }) + .flatten() + }; + + // Determine the sort order for both left and right sides. 
+ let left_order = determine_order(JoinSide::Left); + let right_order = determine_order(JoinSide::Right); + SymmetricHashJoinExec::try_new( hash_join.left().clone(), hash_join.right().clone(), @@ -450,6 +524,8 @@ fn hash_join_convert_symmetric_subrule( hash_join.filter().cloned(), hash_join.join_type(), hash_join.null_equals_null(), + left_order, + right_order, mode, ) .map(|exec| { diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 540f9a6a132b..9855247151b8 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -305,7 +305,7 @@ mod tests { AggregateMode::Partial, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -313,7 +313,7 @@ mod tests { AggregateMode::Final, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ Arc::new(partial_agg), /* input */ schema.clone(), /* input_schema */ )?; @@ -355,7 +355,7 @@ mod tests { AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -396,7 +396,7 @@ mod tests { AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -437,7 +437,7 @@ mod tests { AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -445,7 +445,7 @@ mod tests { AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ Arc::new(group_by_agg), /* input */ schema.clone(), /* input_schema */ )?; @@ -487,7 +487,7 @@ mod tests { AggregateMode::Single, build_group_by(&schema.clone(), vec![]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -549,13 +549,14 @@ mod tests { cast(expressions::lit(1u32), &schema, DataType::Int32)?, &schema, )?); + let agg = TestAggregate::new_count_star(); let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr()], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -565,7 +566,7 @@ mod tests { // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out let expected = [ "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", "MemoryExec: partitions=1, partition_sizes=[1]", ]; let plan: Arc = Arc::new(limit_exec); @@ -588,7 +589,7 @@ mod tests { 
AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ - vec![None], /* filter_expr */ + vec![], /* filter_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index d237a3e8607e..34d1af85565a 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -795,6 +795,8 @@ fn try_swapping_with_sym_hash_join( new_filter, sym_join.join_type(), sym_join.null_equals_null(), + sym_join.right().output_ordering().map(|p| p.to_vec()), + sym_join.left().output_ordering().map(|p| p.to_vec()), sym_join.partition_mode(), )?))) } @@ -2048,6 +2050,8 @@ mod tests { )), &JoinType::Inner, true, + None, + None, StreamJoinPartitionMode::SinglePartition, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 06cfc7282468..aa0c26723767 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -45,12 +45,23 @@ use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarant use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; -/// Interface to pass statistics (min/max/nulls) information to [`PruningPredicate`]. +/// A source of runtime statistical information to [`PruningPredicate`]s. /// -/// Returns statistics for containers / files as Arrow [`ArrayRef`], so the -/// evaluation happens once on a single `RecordBatch`, amortizing the overhead -/// of evaluating of the predicate. This is important when pruning 1000s of -/// containers which often happens in analytic systems. +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns +/// +/// 2. Null counts for columns +/// +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files are returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. /// /// For example, for the following three files with a single column `a`: /// ```text @@ -83,8 +94,11 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option; - /// Return the number of containers (e.g. row groups) being - /// pruned with these statistics (the number of rows in each returned array) + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. fn num_containers(&self) -> usize; /// Return the number of null values for the named column as an @@ -95,13 +109,11 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; - /// Returns an array where each row represents information known about - /// the `values` contained in a column. 
+ /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. /// - /// This API is designed to be used along with [`LiteralGuarantee`] to prove - /// that predicates can not possibly evaluate to `true` and thus prune - /// containers. For example, Parquet Bloom Filters can prove that values are - /// not present. + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. /// /// The returned array has one row for each container, with the following /// meanings: @@ -120,28 +132,34 @@ pub trait PruningStatistics { ) -> Option; } -/// Evaluates filter expressions on statistics such as min/max values and null -/// counts, attempting to prove a "container" (e.g. Parquet Row Group) can be -/// skipped without reading the actual data, potentially leading to significant -/// performance improvements. +/// Used to prove that arbitrary predicates (boolean expression) can not +/// possibly evaluate to `true` given information about a column provided by +/// [`PruningStatistics`]. +/// +/// `PruningPredicate` analyzes filter expressions using statistics such as +/// min/max values and null counts, attempting to prove a "container" (e.g. +/// Parquet Row Group) can be skipped without reading the actual data, +/// potentially leading to significant performance improvements. +/// +/// For example, `PruningPredicate`s are used to prune Parquet Row Groups based +/// on the min/max values found in the Parquet metadata. If the +/// `PruningPredicate` can prove that the filter can never evaluate to `true` +/// for any row in the Row Group, the entire Row Group is skipped during query +/// execution. /// -/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups -/// based on the min/max values found in the Parquet metadata. If the -/// `PruningPredicate` can guarantee that no rows in the Row Group match the -/// filter, the entire Row Group is skipped during query execution. +/// The `PruningPredicate` API is designed to be general, so it can used for +/// pruning other types of containers (e.g. files) based on statistics that may +/// be known from external catalogs (e.g. Delta Lake) or other sources. /// -/// The `PruningPredicate` API is general, allowing it to be used for pruning -/// other types of containers (e.g. files) based on statistics that may be -/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it -/// supports: +/// It currently supports: /// -/// 1. Arbitrary expressions expressions (including user defined functions) +/// 1. Arbitrary expressions (including user defined functions) /// /// 2. Vectorized evaluation (provide more than one set of statistics at a time) /// so it is suitable for pruning 1000s of containers. /// -/// 3. Anything that implements the [`PruningStatistics`] trait, not just -/// Parquet metadata. +/// 3. Any source of information that implements the [`PruningStatistics`] trait +/// (not just Parquet metadata). 
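Editor's note: to make the reworked trait docs concrete, here is a hedged, end-to-end sketch of a hand-rolled `PruningStatistics` source feeding a `PruningPredicate`. Only min/max values are provided and `contained` returns `None`; the calls mirror the ones used in this file's tests, while the exact import paths (notably for `ExecutionProps`) are assumptions.

```rust
use std::collections::HashSet;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, BooleanArray, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{Column, Result, ScalarValue, ToDFSchema};
use datafusion::execution::context::ExecutionProps; // path assumed
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
use datafusion::physical_plan::PhysicalExpr;
use datafusion::prelude::*;

/// Two containers with only min/max statistics for a single column "x".
struct TwoContainers {
    min: ArrayRef,
    max: ArrayRef,
}

impl PruningStatistics for TwoContainers {
    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == "x").then(|| Arc::clone(&self.min))
    }
    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == "x").then(|| Arc::clone(&self.max))
    }
    fn num_containers(&self) -> usize {
        2
    }
    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None
    }
    // No bloom-filter-like information in this toy source.
    fn contained(
        &self,
        _column: &Column,
        _values: &HashSet<ScalarValue>,
    ) -> Option<BooleanArray> {
        None
    }
}

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)]));

    // Predicate `x = 42`, lowered the same way this file's tests do it.
    let df_schema = schema.as_ref().clone().to_dfschema()?;
    let props = ExecutionProps::new();
    let expr: Arc<dyn PhysicalExpr> =
        create_physical_expr(&col("x").eq(lit(42)), &df_schema, &props)?;

    let p = PruningPredicate::try_new(expr, Arc::clone(&schema))?;

    // Container 0: x in [0, 10]  -> 42 can never match, prune it (false).
    // Container 1: x in [0, 100] -> 42 might match, keep it (true).
    let stats = TwoContainers {
        min: Arc::new(Int32Array::from(vec![0, 0])),
        max: Arc::new(Int32Array::from(vec![10, 100])),
    };
    assert_eq!(p.prune(&stats)?, vec![false, true]);
    Ok(())
}
```

The new `literal_guarantees()` accessor added further down exposes the `LiteralGuarantee`s that a richer `contained` implementation (for example one backed by Parquet bloom filters) can answer.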
/// /// # Example /// @@ -154,7 +172,8 @@ pub trait PruningStatistics { /// C: {x_min = 5, x_max = 8} /// ``` /// -/// Applying the `PruningPredicate` will concludes that `A` can be pruned: +/// `PruningPredicate` will conclude that the rows in container `A` can never +/// be true (as the maximum value is only `4`), so it can be pruned: /// /// ```text /// A: false (no rows could possibly match x = 5) @@ -295,6 +314,11 @@ impl PruningPredicate { &self.predicate_expr } + /// Returns a reference to the literal guarantees + pub fn literal_guarantees(&self) -> &[LiteralGuarantee] { + &self.literal_guarantees + } + /// Returns true if this pruning predicate can not prune anything. /// /// This happens if the predicate is a literal `true` and @@ -902,11 +926,17 @@ fn build_is_null_column_expr( } } +/// The maximum number of entries in an `InList` that might be rewritten into +/// an OR chain +const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// /// Returns the pruning predicate as an [`PhysicalExpr`] +/// +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE fn build_predicate_expression( expr: &Arc, schema: &Schema, @@ -936,7 +966,9 @@ fn build_predicate_expression( } } if let Some(in_list) = expr_any.downcast_ref::() { - if !in_list.list().is_empty() && in_list.list().len() < 20 { + if !in_list.list().is_empty() + && in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE + { let eq_op = if in_list.negated() { Operator::NotEq } else { @@ -1934,6 +1966,68 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_between() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + + // test c1 BETWEEN 1 AND 5 + let expr1 = col("c1").between(lit(1), lit(5)); + + // test 1 <= c1 <= 5 + let expr2 = col("c1").gt_eq(lit(1)).and(col("c1").lt_eq(lit(5))); + + let predicate_expr1 = + test_build_predicate_expression(&expr1, &schema, &mut RequiredColumns::new()); + + let predicate_expr2 = + test_build_predicate_expression(&expr2, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr1.to_string(), predicate_expr2.to_string()); + + Ok(()) + } + + #[test] + fn row_group_predicate_between_with_in_list() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 in(1, 2) + let expr1 = col("c1").in_list(vec![lit(1), lit(2)], false); + + // test c2 BETWEEN 4 AND 5 + let expr2 = col("c2").between(lit(4), lit(5)); + + // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 + let expr3 = expr1.and(expr2); + + let expected_expr = "(c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_max@2 >= 4 AND c2_min@3 <= 5"; + let predicate_expr = + test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_in_list_to_many_values() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + // test c1 in(1..21) + // in pruning.rs has MAX_LIST_VALUE_SIZE_REWRITE = 20, more than this value will be rewrite + // always true + let expr = col("c1").in_list((1..=21).map(lit).collect(), false); + + let expected_expr = "true"; + let 
predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); @@ -2017,54 +2111,52 @@ mod tests { DataType::Decimal128(9, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT32 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))), + &schema, + // If the data is written by spark, the physical data type is INT32 in the parquet + // So we use the INT32 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with cast column to other type - let expr = cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with cast column to other type + cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with try cast column to other type - let expr = try_cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with try cast column to other type + try_cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -2072,22 +2164,21 @@ mod tests { DataType::Decimal128(18, 2), true, )])); - // s1 > 5 - let expr = 
col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT64 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i64( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))), + &schema, + // If the data is written by spark, the physical data type is INT64 in the parquet + // So we use the INT32 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i64( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(23,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -2095,22 +2186,22 @@ mod tests { DataType::Decimal128(23, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_decimal128( - vec![Some(0), Some(400), None, Some(300)], // min - vec![Some(500), Some(600), Some(400), None], // max - 23, - 2, + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_decimal128( + vec![Some(0), Some(400), None, Some(300)], // min + vec![Some(500), Some(600), Some(400), None], // max + 23, + 2, + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); } #[test] @@ -2120,10 +2211,6 @@ mod tests { Field::new("s2", DataType::Int32, true), ])); - // Prune using s2 > 5 - let expr = col("s2").gt(lit(5)); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( "s2", ContainerStats::new_i32( @@ -2131,53 +2218,50 @@ mod tests { vec![Some(5), Some(6), None, None], // max ), ); + prune_with_expr( + // Prune using s2 > 5 + col("s2").gt(lit(5)), + &schema, + &statistics, + // s2 [0, 5] ==> no rows should pass + // s2 [4, 6] ==> some rows could pass + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[false, true, true, true], + ); - // s2 [0, 5] ==> no rows should pass - // s2 [4, 6] ==> some rows could pass - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); - - // filter with cast - let expr = cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); + prune_with_expr( + // filter with cast + cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), + 
&schema, + &statistics, + &[false, true, true, true], + ); } #[test] fn prune_not_eq_data() { let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); - // Prune using s2 != 'M' - let expr = col("s1").not_eq(lit("M")); - let expr = logical2physical(&expr, &schema); - - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_utf8( - vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min - vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + prune_with_expr( + // Prune using s2 != 'M' + col("s1").not_eq(lit("M")), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min + vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + ), ), + // s1 [A, Z] ==> might have values that pass predicate + // s1 [A, L] ==> all rows pass the predicate + // s1 [N, Z] ==> all rows pass the predicate + // s1 [M, M] ==> all rows do not pass the predicate + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[true, true, true, false, true, true], ); - - // s1 [A, Z] ==> might have values that pass predicate - // s1 [A, L] ==> all rows pass the predicate - // s1 [N, Z] ==> all rows pass the predicate - // s1 [M, M] ==> all rows do not pass the predicate - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![true, true, true, false, true, true]; - assert_eq!(result, expected); } /// Creates setup for boolean chunk pruning @@ -2216,69 +2300,75 @@ mod tests { fn prune_bool_const_expr() { let (schema, statistics, _, _) = bool_setup(); - // true - let expr = lit(true); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // true + lit(true), + &schema, + &statistics, + &[true, true, true, true, true], + ); - // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" - let expr = lit(false); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // false + // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is + // "all true" + lit(false), + &schema, + &statistics, + &[true, true, true, true, true], + ); } #[test] fn prune_bool_column() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 - let expr = col("b1"); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_true); + prune_with_expr( + // b1 + col("b1"), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 - let expr = col("b1").not(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); 
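Editor's note: these test hunks replace the repeated build-expression / `try_new` / `prune` / assert pattern with a `prune_with_expr` helper whose definition is outside this hunk. Judging from the call sites and the code being removed, it presumably looks roughly like the sketch below; the signature and its reliance on the test-local `TestStatistics` and `logical2physical` are assumptions.

```rust
// Hypothetical reconstruction of the test helper; not shown in this hunk.
fn prune_with_expr(
    expr: Expr,
    schema: &SchemaRef,
    statistics: &TestStatistics,
    expected: &[bool],
) {
    let expr = logical2physical(&expr, schema);
    let p = PruningPredicate::try_new(expr, Arc::clone(schema)).unwrap();
    assert_eq!(p.prune(statistics).unwrap(), expected);
}
```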
+ prune_with_expr( + // !b1 + col("b1").not(), + &schema, + &statistics, + &expected_false, + ); } #[test] fn prune_bool_column_eq_true() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 = true - let expr = col("b1").eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_true); + prune_with_expr( + // b1 = true + col("b1").eq(lit(true)), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column_eq_true() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 = true - let expr = col("b1").not().eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); + prune_with_expr( + // !b1 = true + col("b1").not().eq(lit(true)), + &schema, + &statistics, + &expected_false, + ); } /// Creates a setup for chunk pruning, modeling a int32 column "i" @@ -2313,21 +2403,18 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> unknown (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; // i > 0 - let expr = col("i").gt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); // -i < 0 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + Expr::Negative(Box::new(col("i"))).lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2340,21 +2427,23 @@ mod tests { // i [-11, -1] ==> all rows must pass (must keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, true, true, false]; + let expected_ret = &[true, false, true, true, false]; - // i <= 0 - let expr = col("i").lt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i <= 0 + col("i").lt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); - // -i >= 0 - let expr = Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i >= 0 + Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2367,37 +2456,39 @@ mod tests { // i [-11, -1] ==> no rows could pass in theory (conservatively keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (conservatively keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // cast(i as utf8) <= 0 - let expr = cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = 
logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as utf8) <= 0 + cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as utf8) <= 0 - let expr = try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as utf8) <= 0 + try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // cast(-i as utf8) >= 0 - let expr = - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(-i as utf8) >= 0 + cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(-i as utf8) >= 0 - let expr = - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(-i as utf8) >= 0 + try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2410,14 +2501,15 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - // i = 0 - let expr = col("i").eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i = 0 + col("i").eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2430,19 +2522,21 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - let expr = cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); - let expr = try_cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + try_cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2458,13 +2552,14 @@ mod tests { // i [-11, -1] ==> no rows can pass (could keep) // i [NULL, NULL] ==> unknown (keep) // i 
[1, NULL] ==> no rows can pass (could keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - let expr = cast(col("i"), DataType::Utf8).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Utf8).eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2477,21 +2572,23 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> all rows must pass (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; - // i > -1 - let expr = col("i").gt(lit(-1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > -1 + col("i").gt(lit(-1)), + &schema, + &statistics, + expected_ret, + ); - // -i < 1 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i < 1 + Expr::Negative(Box::new(col("i"))).lt(lit(1)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2500,14 +2597,15 @@ mod tests { // Expression "i IS NULL" when there are no null statistics, // should all be kept - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // i IS NULL, no null statistics - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, no null statistics + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); // provide null counts for each column let statistics = statistics.with_null_counts( @@ -2521,51 +2619,55 @@ mod tests { ], ); - let expected_ret = vec![false, true, true, true, false]; + let expected_ret = &[false, true, true, true, false]; - // i IS NULL, with actual null statistcs - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, with actual null statistcs + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); } #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 let (schema, statistics) = int32_setup(); - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; - // i > int64(0) - let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > int64(0) + col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)), + &schema, + &statistics, + expected_ret, + ); - // cast(i as int64) > int64(0) - let expr = 
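The `i IS NULL` test above is the one case in this hunk where supplying extra statistics changes the expected mask: once per-container null counts are known, any container with a null count of zero can be pruned. A hedged sketch with made-up counts (the real values are elided in this hunk, and the exact `with_null_counts` argument types are assumed):

// Hypothetical null counts, chosen only to be consistent with the expected mask
// in the test above; the actual values used by the test are not visible here.
let statistics = statistics.with_null_counts(
    "i",
    vec![Some(0), Some(1), Some(2), None, Some(0)],
);
prune_with_expr(
    // i IS NULL: keep only containers that may contain nulls
    col("i").is_null(),
    &schema,
    &statistics,
    // null count 0 => pruned; positive or unknown => kept
    &[false, true, true, true, false],
);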
cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as int64) > int64(0) + cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as int64) > int64(0) - let expr = - try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as int64) > int64(0) + try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) - .lt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` + Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + .lt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2721,7 +2823,7 @@ mod tests { &schema, &statistics, // rule out containers ('false) where we know foo is not present - vec![true, false, true, true, false, true, true, false, true], + &[true, false, true, true, false, true, true, false, true], ); // s1 = 'bar' @@ -2730,7 +2832,7 @@ mod tests { &schema, &statistics, // rule out containers where we know bar is not present - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 = 'baz' (unknown value) @@ -2739,7 +2841,7 @@ mod tests { &schema, &statistics, // can't rule out anything - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' AND s1 = 'bar' @@ -2750,7 +2852,7 @@ mod tests { // logically this predicate can't possibly be true (the column can't // take on both values) but we could rule it out if the stats tell // us that both values are not present - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' OR s1 = 'bar' @@ -2759,7 +2861,7 @@ mod tests { &schema, &statistics, // can rule out containers that we know contain neither foo nor bar - vec![true, true, true, true, true, true, false, false, false], + &[true, true, true, true, true, true, false, false, false], ); // s1 = 'foo' OR s1 = 'baz' @@ -2768,7 +2870,7 @@ mod tests { &schema, &statistics, // can't rule out anything container - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' @@ -2781,7 +2883,7 @@ mod tests { &statistics, // can rule out any containers based on knowledge of s1 and `foo`, // `bar` and (`foo`, `bar`) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo @@ -2790,7 +2892,7 @@ mod tests { &schema, 
&statistics, // rule out containers we know for sure only contain foo - vec![false, true, true, false, true, true, false, true, true], + &[false, true, true, false, true, true, false, true, true], ); // s1 != bar @@ -2799,7 +2901,7 @@ mod tests { &schema, &statistics, // rule out when we know for sure s1 has the value bar - vec![false, false, false, true, true, true, true, true, true], + &[false, false, false, true, true, true, true, true, true], ); // s1 != foo AND s1 != bar @@ -2810,7 +2912,7 @@ mod tests { &schema, &statistics, // can rule out any container where we know s1 does not have either 'foo' or 'bar' - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 != foo AND s1 != bar AND s1 != baz @@ -2822,7 +2924,7 @@ mod tests { &schema, &statistics, // can't rule out any container based on knowledge of s1,s2 - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo OR s1 != bar @@ -2833,7 +2935,7 @@ mod tests { &schema, &statistics, // cant' rule out anything based on contains information - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo OR s1 != bar OR s1 != baz @@ -2845,7 +2947,7 @@ mod tests { &schema, &statistics, // cant' rule out anything based on contains information - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -2907,7 +3009,7 @@ mod tests { &schema, &statistics, // rule out containers where we know s1 is not present - vec![true, false, true, true, false, true, true, false, true], + &[true, false, true, true, false, true, true, false, true], ); // s1 = 'foo' OR s2 = 'bar' @@ -2917,7 +3019,7 @@ mod tests { &schema, &statistics, // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' AND s2 != 'bar' @@ -2928,7 +3030,7 @@ mod tests { // can only rule out container where we know either: // 1. s1 doesn't have the value 'foo` or // 2. s2 has only the value of 'bar' - vec![false, false, false, true, false, true, true, false, true], + &[false, false, false, true, false, true, true, false, true], ); // s1 != 'foo' AND s2 != 'bar' @@ -2941,7 +3043,7 @@ mod tests { // Can rule out any container where we know either // 1. s1 has only the value 'foo' // 2. s2 has only the value 'bar' - vec![false, false, false, false, true, true, false, true, true], + &[false, false, false, false, true, true, false, true, true], ); // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') @@ -2953,7 +3055,7 @@ mod tests { &statistics, // Can rule out any container where we know s1 has only the value // 'foo'. 
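These equality and inequality cases all encode the same two rules that the inline comments spell out per container. A self-contained restatement as a sketch; the helper names are illustrative and not part of the test harness:

// `Some(false)` = statistics prove the container never holds the value,
// `Some(true)`  = statistics prove the container holds *only* that value,
// `None`        = unknown.
fn can_prune_eq_or(contains_a: Option<bool>, contains_b: Option<bool>) -> bool {
    // `s1 = 'a' OR s1 = 'b'` is provably false for a container only when the
    // container is known to hold neither value
    contains_a == Some(false) && contains_b == Some(false)
}

fn can_prune_not_eq(only_a: Option<bool>) -> bool {
    // `s1 != 'a'` is provably false only when the container is known to hold
    // nothing but 'a'
    only_a == Some(true)
}

#[test]
fn contains_pruning_rules_sketch() {
    assert!(can_prune_eq_or(Some(false), Some(false)));
    assert!(!can_prune_eq_or(Some(false), None)); // one side unknown => keep
    assert!(can_prune_not_eq(Some(true)));
    assert!(!can_prune_not_eq(None));
}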
Can't use knowledge of s2 and bar to rule out anything - vec![false, true, true, false, true, true, false, true, true], + &[false, true, true, false, true, true, false, true, true], ); // s1 like '%foo%bar%' @@ -2962,7 +3064,7 @@ mod tests { &schema, &statistics, // cant rule out anything with information we know - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 like '%foo%bar%' AND s2 = 'bar' @@ -2973,7 +3075,7 @@ mod tests { &schema, &statistics, // can rule out any container where we know s2 does not have the value 'bar' - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 like '%foo%bar%' OR s2 = 'bar' @@ -2983,7 +3085,7 @@ mod tests { &statistics, // can't rule out anything (we would have to prove that both the // like and the equality must be false) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -3055,7 +3157,7 @@ mod tests { // 1. 0 is outside the min/max range of i // 1. s does not contain foo // (range is false, and contained is false) - vec![true, false, true, false, false, false, true, false, true], + &[true, false, true, false, false, false, true, false, true], ); // i = 0 and s != 'foo' @@ -3066,7 +3168,7 @@ mod tests { // Can rule out containers where either: // 1. 0 is outside the min/max range of i // 2. s only contains foo - vec![false, false, false, true, false, true, true, false, true], + &[false, false, false, true, false, true, true, false, true], ); // i = 0 OR s = 'foo' @@ -3076,7 +3178,7 @@ mod tests { &statistics, // in theory could rule out containers if we had min/max values for // s as well. 
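For the mixed range-plus-contains cases above, the comments state the conjunction rule directly: `i = 0 AND s = 'foo'` can prune a container as soon as either conjunct is provably false for it. A small sketch of that rule; the names and values are illustrative only:

// Illustrative only: `i_min`/`i_max` stand in for the container's min/max
// statistics for column `i`; `s_may_contain_foo` is the contains-style answer
// for column `s` (None = unknown).
fn can_prune_i_eq_0_and_s_eq_foo(
    i_min: i32,
    i_max: i32,
    s_may_contain_foo: Option<bool>,
) -> bool {
    let range_rules_out = !(i_min <= 0 && 0 <= i_max);
    let contains_rules_out = s_may_contain_foo == Some(false);
    range_rules_out || contains_rules_out
}

#[test]
fn conjunction_pruning_sketch() {
    assert!(can_prune_i_eq_0_and_s_eq_foo(1, 10, None)); // 0 outside [1, 10]
    assert!(can_prune_i_eq_0_and_s_eq_foo(-5, 5, Some(false))); // 'foo' provably absent
    assert!(!can_prune_i_eq_0_and_s_eq_foo(-5, 5, None)); // neither side rules it out
}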
But in this case we don't so we can't rule out anything - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -3091,7 +3193,7 @@ mod tests { expr: Expr, schema: &SchemaRef, statistics: &TestStatistics, - expected: Vec, + expected: &[bool], ) { println!("Pruning with expr: {}", expr); let expr = logical2physical(&expr, schema); @@ -3112,6 +3214,6 @@ mod tests { fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + create_physical_expr(expr, &df_schema, &execution_props).unwrap() } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d696c55a8c13..98390ac271d0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -401,13 +401,10 @@ pub trait PhysicalPlanner: Send + Sync { /// `expr`: the expression to convert /// /// `input_dfschema`: the logical plan schema for evaluating `expr` - /// - /// `input_schema`: the physical schema for evaluating `expr` fn create_physical_expr( &self, expr: &Expr, input_dfschema: &DFSchema, - input_schema: &Schema, session_state: &SessionState, ) -> Result>; } @@ -467,21 +464,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { /// `e`: the expression to convert /// /// `input_dfschema`: the logical plan schema for evaluating `e` - /// - /// `input_schema`: the physical schema for evaluating `e` fn create_physical_expr( &self, expr: &Expr, input_dfschema: &DFSchema, - input_schema: &Schema, session_state: &SessionState, ) -> Result> { - create_physical_expr( - expr, - input_dfschema, - input_schema, - session_state.execution_props(), - ) + create_physical_expr(expr, input_dfschema, session_state.execution_props()) } } @@ -654,7 +643,6 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( expr, schema, - &exec_schema, session_state, ) }) @@ -694,7 +682,6 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( e, input.schema(), - &input_exec.schema(), session_state, ) }) @@ -884,7 +871,6 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( e, input_schema, - &input_exec.schema(), session_state, ), physical_name, @@ -899,13 +885,11 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Filter(filter) => { let physical_input = self.create_initial_plan(&filter.input, session_state).await?; - let input_schema = physical_input.as_ref().schema(); let input_dfschema = filter.input.schema(); let runtime_expr = self.create_physical_expr( &filter.predicate, input_dfschema, - &input_schema, session_state, )?; let selectivity = session_state.config().options().optimizer.default_filter_selectivity; @@ -922,7 +906,6 @@ impl DefaultPhysicalPlanner { partitioning_scheme, }) => { let physical_input = self.create_initial_plan(input, session_state).await?; - let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { LogicalPartitioning::RoundRobinBatch(n) => { @@ -935,7 +918,6 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( e, input_dfschema, - &input_schema, session_state, ) }) @@ -953,14 +935,12 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Sort(Sort { expr, input, fetch, .. 
}) => { let physical_input = self.create_initial_plan(input, session_state).await?; - let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = expr .iter() .map(|e| create_physical_sort_expr( e, input_dfschema, - &input_schema, session_state.execution_props(), )) .collect::>>()?; @@ -1107,7 +1087,6 @@ impl DefaultPhysicalPlanner { let filter_expr = create_physical_expr( expr, &filter_df_schema, - &filter_schema, session_state.execution_props(), )?; let column_indices = join_utils::JoinFilter::build_column_indices(left_field_indices, right_field_indices); @@ -1348,12 +1327,7 @@ impl DefaultPhysicalPlanner { ) } expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( - self.create_physical_expr( - expr, - input_dfschema, - input_schema, - session_state, - ), + self.create_physical_expr(expr, input_dfschema, session_state), physical_name(expr), ))?])), } @@ -1363,12 +1337,7 @@ impl DefaultPhysicalPlanner { .iter() .map(|e| { tuple_err(( - self.create_physical_expr( - e, - input_dfschema, - input_schema, - session_state, - ), + self.create_physical_expr(e, input_dfschema, session_state), physical_name(e), )) }) @@ -1406,7 +1375,6 @@ fn merge_grouping_set_physical_expr( grouping_set_expr.push(get_physical_expr_pair( expr, input_dfschema, - input_schema, session_state, )?); @@ -1461,12 +1429,7 @@ fn create_cube_physical_expr( session_state, )?); - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) + all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) } let mut groups: Vec> = Vec::with_capacity(num_groups); @@ -1509,12 +1472,7 @@ fn create_rollup_physical_expr( session_state, )?); - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) + all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) 
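All of the planner changes in this file are the same API simplification: `create_physical_expr` and `create_physical_sort_expr` no longer take the physical `&Schema`, only the logical `DFSchema` and the `ExecutionProps`. A hedged sketch of the new call shape, with the generic return type written out in full; the column name and literal are placeholders, and imports follow whatever the surrounding file already uses:

// Sketch of the updated call shape after this change.
fn plan_predicate(schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
    let df_schema = schema.clone().to_dfschema()?;
    let execution_props = ExecutionProps::new();
    // the physical `&Schema` argument is gone; only the logical DFSchema and
    // the ExecutionProps are needed now
    create_physical_expr(&col("a").gt(lit(5)), &df_schema, &execution_props)
}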
} for total in 0..=num_of_exprs { @@ -1541,12 +1499,8 @@ fn get_null_physical_expr_pair( input_schema: &Schema, session_state: &SessionState, ) -> Result<(Arc, String)> { - let physical_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - session_state.execution_props(), - )?; + let physical_expr = + create_physical_expr(expr, input_dfschema, session_state.execution_props())?; let physical_name = physical_name(&expr.clone())?; let data_type = physical_expr.data_type(input_schema)?; @@ -1559,15 +1513,10 @@ fn get_null_physical_expr_pair( fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, - input_schema: &Schema, session_state: &SessionState, ) -> Result<(Arc, String)> { - let physical_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - session_state.execution_props(), - )?; + let physical_expr = + create_physical_expr(expr, input_dfschema, session_state.execution_props())?; let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } @@ -1611,35 +1560,16 @@ pub fn create_window_expr_with_name( }) => { let args = args .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) .collect::>>()?; let partition_by = partition_by .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) .collect::>>()?; let order_by = order_by .iter() .map(|e| { - create_physical_sort_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) + create_physical_sort_expr(e, logical_input_schema, execution_props) }) .collect::>>()?; if !is_window_valid(window_frame) { @@ -1711,20 +1641,12 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( }) => { let args = args .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) .collect::>>()?; let filter = match filter { Some(e) => Some(create_physical_expr( e, logical_input_schema, - physical_input_schema, execution_props, )?), None => None, @@ -1736,7 +1658,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( create_physical_sort_expr( expr, logical_input_schema, - physical_input_schema, execution_props, ) }) @@ -1804,7 +1725,6 @@ pub fn create_aggregate_expr_and_maybe_filter( pub fn create_physical_sort_expr( e: &Expr, input_dfschema: &DFSchema, - input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { if let Expr::Sort(expr::Sort { @@ -1814,12 +1734,7 @@ pub fn create_physical_sort_expr( }) = e { Ok(PhysicalSortExpr { - expr: create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?, + expr: create_physical_expr(expr, input_dfschema, execution_props)?, options: SortOptions { descending: !asc, nulls_first: *nulls_first, @@ -2180,7 +2095,6 @@ mod tests { let expr = planner.create_physical_expr( &col("a").not(), &dfschema, - &schema, &make_session_state(), )?; let expected = expressions::not(expressions::col("a", &schema)?)?; diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 336a6804637a..1047c3dd4e48 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -166,12 +166,8 @@ impl 
TestParquetFile { if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, df_schema.clone()).unwrap(); - let physical_filter_expr = create_physical_expr( - &filter, - &df_schema, - self.schema.as_ref(), - &ExecutionProps::default(), - )?; + let physical_filter_expr = + create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; let parquet_exec = Arc::new(ParquetExec::new( scan_config, Some(physical_filter_expr.clone()), diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 943f7fdbf4ac..672498a9f84e 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -63,6 +63,7 @@ enum Scenario { Timestamps, Dates, Int32, + Int32Range, Float64, Decimal, DecimalLargePrecision, @@ -113,12 +114,24 @@ impl TestOutput { self.metric_value("predicate_evaluation_errors") } - /// The number of times the pruning predicate evaluation errors + /// The number of row_groups pruned by bloom filter + fn row_groups_pruned_bloom_filter(&self) -> Option { + self.metric_value("row_groups_pruned_bloom_filter") + } + + /// The number of row_groups pruned by statistics + fn row_groups_pruned_statistics(&self) -> Option { + self.metric_value("row_groups_pruned_statistics") + } + + /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { - self.metric_value("row_groups_pruned") + self.row_groups_pruned_bloom_filter() + .zip(self.row_groups_pruned_statistics()) + .map(|(a, b)| a + b) } - /// The number of times the pruning predicate evaluation errors + /// The number of row pages pruned fn row_pages_pruned(&self) -> Option { self.metric_value("page_index_rows_filtered") } @@ -145,7 +158,11 @@ impl ContextWithParquet { mut config: SessionConfig, ) -> Self { let file = match unit { - Unit::RowGroup => make_test_file_rg(scenario).await, + Unit::RowGroup => { + let config = config.options_mut(); + config.execution.parquet.bloom_filter_enabled = true; + make_test_file_rg(scenario).await + } Unit::Page => { let config = config.options_mut(); config.execution.parquet.enable_page_index = true; @@ -360,6 +377,13 @@ fn make_int32_batch(start: i32, end: i32) -> RecordBatch { RecordBatch::try_new(schema, vec![array.clone()]).unwrap() } +fn make_int32_range(start: i32, end: i32) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let v = vec![start, end]; + let array = Arc::new(Int32Array::from(v)) as ArrayRef; + RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +} + /// Return record batch with f64 vector /// /// Columns are named @@ -508,6 +532,9 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_int32_batch(5, 10), ] } + Scenario::Int32Range => { + vec![make_int32_range(0, 10), make_int32_range(200000, 300000)] + } Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -565,6 +592,7 @@ async fn make_test_file_rg(scenario: Scenario) -> NamedTempFile { let props = WriterProperties::builder() .set_max_row_group_size(5) + .set_bloom_filter_enabled(true) .build(); let batches = create_data_batch(scenario); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 23a56bc821d4..d182986ebbdc 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -67,8 +67,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { let df_schema = 
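The metric plumbing in this test module replaces the single `row_groups_pruned` counter with separate bloom-filter and statistics counters, and the combined accessor only reports a total when both underlying metrics are present. A tiny sketch of that `Option` combination, mirroring the `zip`/`map` used in `row_groups_pruned`; the `usize` value type is assumed here since the generics are elided above:

// Mirrors the accessor: `None` from either underlying metric makes the combined
// count `None` rather than silently treating the missing metric as zero.
fn total_row_groups_pruned(
    pruned_by_bloom_filter: Option<usize>,
    pruned_by_statistics: Option<usize>,
) -> Option<usize> {
    pruned_by_bloom_filter
        .zip(pruned_by_statistics)
        .map(|(bloom, stats)| bloom + stats)
}

#[test]
fn combined_pruning_metric_sketch() {
    assert_eq!(total_row_groups_pruned(Some(1), Some(3)), Some(4));
    assert_eq!(total_row_groups_pruned(None, Some(3)), None);
}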
schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); - let predicate = - create_physical_expr(&filter, &df_schema, &schema, &execution_props).unwrap(); + let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); let parquet_exec = ParquetExec::new( FileScanConfig { diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 7b5470fe350a..2bc5bd3f1ca7 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -20,6 +20,7 @@ //! expected. use datafusion::prelude::SessionConfig; use datafusion_common::ScalarValue; +use itertools::Itertools; use crate::parquet::Unit::RowGroup; use crate::parquet::{ContextWithParquet, Scenario}; @@ -48,6 +49,38 @@ async fn test_prune( ); } +/// check row group pruning by bloom filter and statistics independently +async fn test_prune_verbose( + case_data_type: Scenario, + sql: &str, + expected_errors: Option, + expected_row_group_pruned_sbbf: Option, + expected_row_group_pruned_statistics: Option, + expected_results: usize, +) { + let output = ContextWithParquet::new(case_data_type, RowGroup) + .await + .query(sql) + .await; + + println!("{}", output.description()); + assert_eq!(output.predicate_evaluation_errors(), expected_errors); + assert_eq!( + output.row_groups_pruned_bloom_filter(), + expected_row_group_pruned_sbbf + ); + assert_eq!( + output.row_groups_pruned_statistics(), + expected_row_group_pruned_statistics + ); + assert_eq!( + output.result_rows, + expected_results, + "{}", + output.description() + ); +} + #[tokio::test] async fn prune_timestamps_nanos() { test_prune( @@ -336,16 +369,38 @@ async fn prune_int32_eq_in_list() { #[tokio::test] async fn prune_int32_eq_in_list_2() { // result of sql "SELECT * FROM t where in (1000)", prune all - test_prune( + // test whether statistics works + test_prune_verbose( Scenario::Int32, "SELECT * FROM t where i in (1000)", Some(0), + Some(0), Some(4), 0, ) .await; } +#[tokio::test] +async fn prune_int32_eq_large_in_list() { + // result of sql "SELECT * FROM t where i in (2050...2582)", prune all + // test whether sbbf works + test_prune_verbose( + Scenario::Int32Range, + format!( + "SELECT * FROM t where i in ({})", + (200050..200082).join(",") + ) + .as_str(), + Some(0), + Some(1), + // we don't support pruning by statistics for in_list with more than 20 elements currently + Some(0), + 0, + ) + .await; +} + #[tokio::test] async fn prune_int32_eq_in_list_negated() { // result of sql "SELECT * FROM t where not in (1)" prune nothing diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index abe6ab283aff..dd8eb52f67c7 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match s { - ScalarValue::Utf8(Some(month)) => month, - s => panic!("Expected month as Utf8 found {s:?}"), + let month = match extract_as_utf(&s) { + Some(month) => month, + s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } +fn extract_as_utf(v: &ScalarValue) -> Option { + if let ScalarValue::Dictionary(_, v) = v { + if let ScalarValue::Utf8(v) = 
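// Sketch (not part of the patch): `extract_as_utf`, whose definition continues
// just below, unwraps the Dictionary(_, Utf8) scalars that partition columns
// now produce; its elided return type is `Option<String>`. Assumed usage, with
// a made-up dictionary key type and month string:
fn extract_as_utf_usage_sketch() {
    let s = ScalarValue::Dictionary(
        Box::new(DataType::UInt16),             // key type assumed
        Box::new(ScalarValue::from("2021-09")), // hypothetical partition value
    );
    assert_eq!(extract_as_utf(&s), Some("2021-09".to_string()));
}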
v.as_ref() { + return v.clone(); + } + } + None +} + #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 37f8cefc9080..a1d9a02cf6b1 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -738,7 +738,8 @@ async fn parquet_explain_analyze() { // should contain aggregated stats assert_contains!(&formatted, "output_rows=8"); - assert_contains!(&formatted, "row_groups_pruned=0"); + assert_contains!(&formatted, "row_groups_pruned_bloom_filter=0"); + assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } #[tokio::test] @@ -754,7 +755,8 @@ async fn parquet_explain_analyze_verbose() { .to_string(); // should contain the raw per file stats (with the label) - assert_contains!(&formatted, "row_groups_pruned{partition=0"); + assert_contains!(&formatted, "row_groups_pruned_bloom_filter{partition=0"); + assert_contains!(&formatted, "row_groups_pruned_statistics{partition=0"); } #[tokio::test] diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 8ac0e3e5ef19..e8a3d27c089a 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -19,55 +19,6 @@ use datafusion::datasource::empty::EmptyTable; use super::*; -#[tokio::test] -async fn test_boolean_expressions() -> Result<()> { - test_expression!("true", "true"); - test_expression!("false", "false"); - test_expression!("false = false", "true"); - test_expression!("true = false", "false"); - Ok(()) -} - -#[tokio::test] -async fn test_mathematical_expressions_with_null() -> Result<()> { - test_expression!("sqrt(NULL)", "NULL"); - test_expression!("cbrt(NULL)", "NULL"); - test_expression!("sin(NULL)", "NULL"); - test_expression!("cos(NULL)", "NULL"); - test_expression!("tan(NULL)", "NULL"); - test_expression!("asin(NULL)", "NULL"); - test_expression!("acos(NULL)", "NULL"); - test_expression!("atan(NULL)", "NULL"); - test_expression!("sinh(NULL)", "NULL"); - test_expression!("cosh(NULL)", "NULL"); - test_expression!("tanh(NULL)", "NULL"); - test_expression!("asinh(NULL)", "NULL"); - test_expression!("acosh(NULL)", "NULL"); - test_expression!("atanh(NULL)", "NULL"); - test_expression!("floor(NULL)", "NULL"); - test_expression!("ceil(NULL)", "NULL"); - test_expression!("round(NULL)", "NULL"); - test_expression!("trunc(NULL)", "NULL"); - test_expression!("abs(NULL)", "NULL"); - test_expression!("signum(NULL)", "NULL"); - test_expression!("exp(NULL)", "NULL"); - test_expression!("ln(NULL)", "NULL"); - test_expression!("log2(NULL)", "NULL"); - test_expression!("log10(NULL)", "NULL"); - test_expression!("power(NULL, 2)", "NULL"); - test_expression!("power(NULL, NULL)", "NULL"); - test_expression!("power(2, NULL)", "NULL"); - test_expression!("atan2(NULL, NULL)", "NULL"); - test_expression!("atan2(1, NULL)", "NULL"); - test_expression!("atan2(NULL, 1)", "NULL"); - test_expression!("nanvl(NULL, NULL)", "NULL"); - test_expression!("nanvl(1, NULL)", "NULL"); - test_expression!("nanvl(NULL, 1)", "NULL"); - test_expression!("isnan(NULL)", "NULL"); - test_expression!("iszero(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_encoding_expressions() -> Result<()> { @@ -128,14 +79,6 @@ async fn test_encoding_expressions() -> Result<()> { Ok(()) } -#[should_panic(expected = "Invalid timezone \\\"Foo\\\": 'Foo' is not a valid 
timezone")] -#[tokio::test] -async fn test_array_cast_invalid_timezone_will_panic() { - let ctx = SessionContext::new(); - let sql = "SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some(\"Foo\"))')"; - execute(&ctx, sql).await; -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_crypto_expressions() -> Result<()> { @@ -212,242 +155,6 @@ async fn test_crypto_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_array_index() -> Result<()> { - // By default PostgreSQL uses a one-based numbering convention for arrays, that is, an array of n elements starts with array[1] and ends with array[n] - test_expression!("([5,4,3,2,1])[1]", "5"); - test_expression!("([5,4,3,2,1])[2]", "4"); - test_expression!("([5,4,3,2,1])[5]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1]", "[1, 2]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3]", "[3, 4]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1][1]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[2][2]", "3"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3][2]", "4"); - // out of bounds - test_expression!("([5,4,3,2,1])[0]", "NULL"); - test_expression!("([5,4,3,2,1])[6]", "NULL"); - // test_expression!("([5,4,3,2,1])[-1]", "NULL"); - test_expression!("([5,4,3,2,1])[100]", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_array_literals() -> Result<()> { - // Named, just another syntax - test_expression!("ARRAY[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - // Unnamed variant - test_expression!("[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - test_expression!("[true, false]", "[true, false]"); - test_expression!("['str1', 'str2']", "[str1, str2]"); - test_expression!("[[1,2], [3,4]]", "[[1, 2], [3, 4]]"); - - // TODO: Not supported in parser, uncomment when it will be available - // test_expression!( - // "[]", - // "[]" - // ); - - Ok(()) -} - -#[tokio::test] -async fn test_struct_literals() -> Result<()> { - test_expression!("STRUCT(1,2,3,4,5)", "{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5}"); - test_expression!("STRUCT(Null)", "{c0: }"); - test_expression!("STRUCT(2)", "{c0: 2}"); - test_expression!("STRUCT('1',Null)", "{c0: 1, c1: }"); - test_expression!("STRUCT(true, false)", "{c0: true, c1: false}"); - test_expression!("STRUCT('str1', 'str2')", "{c0: str1, c1: str2}"); - - Ok(()) -} - -#[tokio::test] -async fn binary_bitwise_shift() -> Result<()> { - test_expression!("2 << 10", "2048"); - test_expression!("2048 >> 10", "2"); - test_expression!("2048 << NULL", "NULL"); - test_expression!("2048 >> NULL", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_interval_expressions() -> Result<()> { - // day nano intervals - test_expression!( - "interval '1'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 second'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '500 milliseconds'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '5 second'", - "0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs" - ); - test_expression!( - "interval '0.5 minute'", - "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - ); - // https://github.com/apache/arrow-rs/issues/4424 - // test_expression!( - // "interval '.5 minute'", - // "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - // ); - test_expression!( - "interval '5 minute'", - "0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 minute 1 second'", - "0 
years 0 mons 0 days 0 hours 5 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 hour'", - "0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 hour'", - "0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day'", - "0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 week'", - "0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 weeks'", - "0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day 1'", - "0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.5'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '0.5 day 1'", - "0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.49 day'", - "0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs" - ); - test_expression!( - "interval '0.499 day'", - "0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs" - ); - test_expression!( - "interval '0.4999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs" - ); - test_expression!( - "interval '0.49999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs" - ); - test_expression!( - "interval '0.49999999999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs" - ); - test_expression!( - "interval '5 day'", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - // Hour is ignored, this matches PostgreSQL - test_expression!( - "interval '5 day' hour", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", - "0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs" - ); - // month intervals - test_expression!( - "interval '0.5 month'", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5' month", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 month'", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1' MONTH", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 month'", - "0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '13 month'", - "0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5 year'", - "0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year'", - "0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 decade'", - "0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 decades'", - "0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 century'", - "0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 year'", - "0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2' year", - "0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - // complex - test_expression!( - "interval '1 year 1 day'", - "0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour'", - "0 years 12 mons 1 days 1 hours 0 mins 
0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute'", - "0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute 1 second'", - "0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs" - ); - - Ok(()) -} - #[cfg(feature = "unicode_expressions")] #[tokio::test] async fn test_substring_expr() -> Result<()> { @@ -458,108 +165,6 @@ async fn test_substring_expr() -> Result<()> { Ok(()) } -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn test_string_expressions_batch1() -> Result<()> { - test_expression!("ascii('')", "0"); - test_expression!("ascii('x')", "120"); - test_expression!("ascii(NULL)", "NULL"); - test_expression!("bit_length('')", "0"); - test_expression!("bit_length('chars')", "40"); - test_expression!("bit_length('josé')", "40"); - test_expression!("bit_length(NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); - test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); - test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); - test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); - test_expression!("btrim(NULL, 'xyz')", "NULL"); - test_expression!("chr(CAST(120 AS int))", "x"); - test_expression!("chr(CAST(128175 AS int))", "💯"); - test_expression!("chr(CAST(NULL AS int))", "NULL"); - test_expression!("concat('a','b','c')", "abc"); - test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); - test_expression!("concat(NULL)", ""); - test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); - test_expression!("concat_ws('|','a','b','c')", "a|b|c"); - test_expression!("concat_ws('|',NULL)", ""); - test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); - test_expression!("concat_ws('|','a',NULL)", "a"); - test_expression!("concat_ws('|','a',NULL,NULL)", "a"); - test_expression!("initcap('')", ""); - test_expression!("initcap('hi THOMAS')", "Hi Thomas"); - test_expression!("initcap(NULL)", "NULL"); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ')", "zzzytest "); - test_expression!("ltrim('zzzytest', 'xyz')", "test"); - test_expression!("ltrim(NULL, 'xyz')", "NULL"); - test_expression!("octet_length('')", "0"); - test_expression!("octet_length('chars')", "5"); - test_expression!("octet_length('josé')", "5"); - test_expression!("octet_length(NULL)", "NULL"); - test_expression!("repeat('Pg', 4)", "PgPgPgPg"); - test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); - test_expression!("repeat(NULL, 4)", "NULL"); - test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); - test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); - test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); - test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); - test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); - test_expression!("rtrim(' testxxzx ')", " testxxzx"); - test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); - test_expression!("rtrim('testxxzx', 'xyz')", "test"); - test_expression!("rtrim(NULL, 'xyz')", "NULL"); - Ok(()) -} - -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn 
test_string_expressions_batch2() -> Result<()> { - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); - test_expression!("split_part(NULL, '~@~', 20)", "NULL"); - test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); - test_expression!( - "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", - "NULL" - ); - test_expression!("starts_with('alphabet', 'alph')", "true"); - test_expression!("starts_with('alphabet', 'blph')", "false"); - test_expression!("starts_with(NULL, 'blph')", "NULL"); - test_expression!("starts_with('alphabet', NULL)", "NULL"); - test_expression!("to_hex(2147483647)", "7fffffff"); - test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); - test_expression!("to_hex(CAST(NULL AS int))", "NULL"); - test_expression!("trim(' tom ')", "tom"); - test_expression!("trim(LEADING ' tom ')", "tom "); - test_expression!("trim(TRAILING ' tom ')", " tom"); - test_expression!("trim(BOTH ' tom ')", "tom"); - test_expression!("trim(LEADING ' ' FROM ' tom ')", "tom "); - test_expression!("trim(TRAILING ' ' FROM ' tom ')", " tom"); - test_expression!("trim(BOTH ' ' FROM ' tom ')", "tom"); - test_expression!("trim(' ' FROM ' tom ')", "tom"); - test_expression!("trim(LEADING 'x' FROM 'xxxtomxxx')", "tomxxx"); - test_expression!("trim(TRAILING 'x' FROM 'xxxtomxxx')", "xxxtom"); - test_expression!("trim(BOTH 'x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim('x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdefxyx"); - test_expression!("trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx')", "xyxabcxyzdef"); - test_expression!("trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim('xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim(' tom')", "tom"); - test_expression!("trim('')", ""); - test_expression!("trim('tom ')", "tom"); - test_expression!("upper('')", ""); - test_expression!("upper('tom')", "TOM"); - test_expression!("upper(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "regex_expressions"), ignore)] async fn test_regex_expressions() -> Result<()> { @@ -593,329 +198,6 @@ async fn test_regex_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_cast_expressions() -> Result<()> { - test_expression!("CAST('0' AS INT)", "0"); - test_expression!("CAST(NULL AS INT)", "NULL"); - test_expression!("TRY_CAST('0' AS INT)", "0"); - test_expression!("TRY_CAST('x' AS INT)", "NULL"); - Ok(()) -} - -#[tokio::test] -#[ignore] -// issue: https://github.com/apache/arrow-datafusion/issues/6596 -async fn test_array_cast_expressions() -> Result<()> { - test_expression!("CAST([1,2,3,4] AS INT[])", "[1, 2, 3, 4]"); - test_expression!( - "CAST([1,2,3,4] AS NUMERIC(10,4)[])", - "[1.0000, 2.0000, 3.0000, 4.0000]" - ); - Ok(()) -} - -#[tokio::test] -async fn test_random_expression() -> Result<()> { - let ctx = SessionContext::new(); - let sql = "SELECT random() r1"; - let actual = execute(&ctx, sql).await; - let r1 = actual[0][0].parse::().unwrap(); - assert!(0.0 <= r1); - assert!(r1 < 1.0); - Ok(()) -} - -#[tokio::test] -async fn test_uuid_expression() -> Result<()> { - let ctx = SessionContext::new(); - let sql = "SELECT uuid()"; - let actual = execute(&ctx, sql).await; - let uuid = actual[0][0].parse::().unwrap(); - assert_eq!(uuid.get_version_num(), 4); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part() -> Result<()> { - 
test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000.0"); - test_expression!( - "EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00')", - "2020.0" - ); - test_expression!("date_part('QUARTER', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "3.0" - ); - test_expression!("date_part('MONTH', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "9.0" - ); - test_expression!("date_part('WEEK', CAST('2003-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "37.0" - ); - test_expression!("date_part('DAY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "8.0" - ); - test_expression!("date_part('DOY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "252.0" - ); - test_expression!("date_part('DOW', CAST('2000-01-01' AS DATE))", "6.0"); - test_expression!( - "EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "2.0" - ); - test_expression!("date_part('HOUR', CAST('2000-01-01' AS DATE))", "0.0"); - test_expression!( - "EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - test_expression!( - "date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - - // Keep precision when coercing Utf8 to Timestamp - test_expression!( - "date_part('second', '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_extract_epoch() -> Result<()> { - // timestamp - test_expression!( - "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", - "-3155646649.744" - ); - test_expression!( - "extract(epoch from '2000-01-01T00:00:00.000'::timestamp)", - "946684800.0" - ); - test_expression!( - "extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00'))", - "946684800.0" - ); - test_expression!("extract(epoch from NULL::timestamp)", "NULL"); - // date - test_expression!( - 
"extract(epoch from arrow_cast('1970-01-01', 'Date32'))", - "0.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-02', 'Date32'))", - "86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-11', 'Date32'))", - "864000.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1969-12-31', 'Date32'))", - "-86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-01', 'Date64'))", - "0.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-02', 'Date64'))", - "86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-11', 'Date64'))", - "864000.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1969-12-31', 'Date64'))", - "-86400.0" - ); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part_func() -> Result<()> { - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "year" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "quarter" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "month" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "week" - ), - "true" - ); - test_expression!( - format!("(date_part('{0}', now()) = EXTRACT({0} FROM now()))", "day"), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "hour" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "minute" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "second" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "millisecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "microsecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "nanosecond" - ), - "true" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_in_list_scalar() -> Result<()> { - test_expression!("'a' IN ('a','b')", "true"); - test_expression!("'c' IN ('a','b')", "false"); - test_expression!("'c' NOT IN ('a','b')", "true"); - test_expression!("'a' NOT IN ('a','b')", "false"); - test_expression!("NULL IN ('a','b')", "NULL"); - test_expression!("NULL NOT IN ('a','b')", "NULL"); - test_expression!("'a' IN ('a','b',NULL)", "true"); - test_expression!("'c' IN ('a','b',NULL)", "NULL"); - test_expression!("'a' NOT IN ('a','b',NULL)", "false"); - test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); - test_expression!("0 IN (0,1,2)", "true"); - test_expression!("3 IN (0,1,2)", "false"); - test_expression!("3 NOT IN (0,1,2)", "true"); - test_expression!("0 NOT IN (0,1,2)", "false"); - test_expression!("NULL IN (0,1,2)", "NULL"); - test_expression!("NULL NOT IN (0,1,2)", "NULL"); - test_expression!("0 IN (0,1,2,NULL)", "true"); - test_expression!("3 IN (0,1,2,NULL)", "NULL"); - test_expression!("0 NOT IN (0,1,2,NULL)", "false"); - test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); - test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); - test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); - test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); - 
test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("'1' IN ('a','b',1)", "true"); - test_expression!("'2' IN ('a','b',1)", "false"); - test_expression!("'2' NOT IN ('a','b',1)", "true"); - test_expression!("'1' NOT IN ('a','b',1)", "false"); - test_expression!("NULL IN ('a','b',1)", "NULL"); - test_expression!("NULL NOT IN ('a','b',1)", "NULL"); - test_expression!("'1' IN ('a','b',NULL,1)", "true"); - test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); - test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); - test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); - Ok(()) -} - #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs deleted file mode 100644 index 58f0ac21d951..000000000000 --- a/datafusion/core/tests/sql/group_by.rs +++ /dev/null @@ -1,253 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use super::*; -use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{DataType, TimeUnit}; - -#[tokio::test] -async fn group_by_date_trunc() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = SessionContext::new(); - let schema = Arc::new(Schema::new(vec![ - Field::new("c2", DataType::UInt64, false), - Field::new( - "t1", - DataType::Timestamp(TimeUnit::Microsecond, None), - false, - ), - ])); - - // generate a partitioned file - for partition in 0..4 { - let filename = format!("partition-{}.{}", partition, "csv"); - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for i in 0..10 { - let data = format!("{},2020-12-{}T00:00:00.000Z\n", i, i + 10); - file.write_all(data.as_bytes())?; - } - } - - ctx.register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).has_header(false), - ) - .await?; - - let results = plan_and_collect( - &ctx, - "SELECT date_trunc('week', t1) as week, SUM(c2) FROM test GROUP BY date_trunc('week', t1)", - ).await?; - - let expected = [ - "+---------------------+--------------+", - "| week | SUM(test.c2) |", - "+---------------------+--------------+", - "| 2020-12-07T00:00:00 | 24 |", - "| 2020-12-14T00:00:00 | 156 |", - "+---------------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn group_by_limit() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_groupby_context(&tmp_dir).await?; - - let sql = "SELECT trace_id, MAX(ts) from traces group by trace_id order by MAX(ts) desc limit 4"; - let dataframe = ctx.sql(sql).await?; - - // ensure we see `lim=[4]` - let physical_plan = dataframe.create_physical_plan().await?; - let mut expected_physical_plan = r#" -GlobalLimitExec: skip=0, fetch=4 - SortExec: TopK(fetch=4), expr=[MAX(traces.ts)@1 DESC] - AggregateExec: mode=Single, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.ts)], lim=[4] - "#.trim().to_string(); - let actual_phys_plan = - format_plan(physical_plan.clone(), &mut expected_physical_plan); - assert_eq!(actual_phys_plan, expected_physical_plan); - - let batches = collect(physical_plan, ctx.task_ctx()).await?; - let expected = r#" -+----------+----------------------+ -| trace_id | MAX(traces.ts) | -+----------+----------------------+ -| 9 | 2020-12-01T00:00:18Z | -| 8 | 2020-12-01T00:00:17Z | -| 7 | 2020-12-01T00:00:16Z | -| 6 | 2020-12-01T00:00:15Z | -+----------+----------------------+ -"# - .trim(); - let actual = format!("{}", pretty_format_batches(&batches)?); - assert_eq!(actual, expected); - - Ok(()) -} - -fn format_plan( - physical_plan: Arc, - expected_phys_plan: &mut String, -) -> String { - let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); - let last_line = actual_phys_plan - .as_str() - .lines() - .last() - .expect("Plan should not be empty"); - - expected_phys_plan.push('\n'); - expected_phys_plan.push_str(last_line); - expected_phys_plan.push('\n'); - actual_phys_plan -} - -async fn create_groupby_context(tmp_dir: &TempDir) -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("trace_id", DataType::Utf8, false), - Field::new( - "ts", - DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), - false, - ), - ])); - - // generate a file - let filename = "traces.csv"; - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for trace_id in 0..10 { - for ts in 0..10 { - 
let ts = trace_id + ts; - let data = format!("\"{trace_id}\",2020-12-01T00:00:{ts:02}.000Z\n"); - file.write_all(data.as_bytes())?; - } - } - - let cfg = SessionConfig::new().with_target_partitions(1); - let ctx = SessionContext::new_with_config(cfg); - ctx.register_csv( - "traces", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).has_header(false), - ) - .await?; - Ok(ctx) -} - -#[tokio::test] -async fn group_by_dictionary() { - async fn run_test_case() { - let ctx = SessionContext::new(); - - // input data looks like: - // A, 1 - // B, 2 - // A, 2 - // A, 4 - // C, 1 - // A, 1 - - let dict_array: DictionaryArray = - vec!["A", "B", "A", "A", "C", "A"].into_iter().collect(); - let dict_array = Arc::new(dict_array); - - let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into(); - let val_array = Arc::new(val_array); - - let schema = Arc::new(Schema::new(vec![ - Field::new("dict", dict_array.data_type().clone(), false), - Field::new("val", val_array.data_type().clone(), false), - ])); - - let batch = - RecordBatch::try_new(schema.clone(), vec![dict_array, val_array]).unwrap(); - - ctx.register_batch("t", batch).unwrap(); - - let results = - plan_and_collect(&ctx, "SELECT dict, count(val) FROM t GROUP BY dict") - .await - .expect("ran plan correctly"); - - let expected = [ - "+------+--------------+", - "| dict | COUNT(t.val) |", - "+------+--------------+", - "| A | 4 |", - "| B | 1 |", - "| C | 1 |", - "+------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - // Now, use dict as an aggregate - let results = - plan_and_collect(&ctx, "SELECT val, count(dict) FROM t GROUP BY val") - .await - .expect("ran plan correctly"); - - let expected = [ - "+-----+---------------+", - "| val | COUNT(t.dict) |", - "+-----+---------------+", - "| 1 | 3 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-----+---------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - // Now, use dict as an aggregate - let results = plan_and_collect( - &ctx, - "SELECT val, count(distinct dict) FROM t GROUP BY val", - ) - .await - .expect("ran plan correctly"); - - let expected = [ - "+-----+------------------------+", - "| val | COUNT(DISTINCT t.dict) |", - "+-----+------------------------+", - "| 1 | 2 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-----+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - } - - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; - run_test_case::().await; -} diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index d1f270b540b5..0cc102002ec3 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -124,6 +124,74 @@ async fn join_change_in_planner() -> Result<()> { [ "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, 
preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" + ] + }; + let mut actual: Vec<&str> = formatted.trim().lines().collect(); + // Remove CSV lines + actual.remove(4); + actual.remove(7); + + assert_eq!( + expected, + actual[..], + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) +} + +#[tokio::test] +async fn join_no_order_on_filter() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::new_with_config(config); + let tmp_dir = TempDir::new().unwrap(); + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone()).unwrap(); + // Create schema + let schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + Field::new("a3", DataType::UInt32, false), + ])); + // Specify the ordering: + let file_sort_order = vec![[datafusion_expr::col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>()]; + register_unbounded_file_with_ordering( + &ctx, + schema.clone(), + &left_file_path, + "left", + file_sort_order.clone(), + )?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone()).unwrap(); + register_unbounded_file_with_ordering( + &ctx, + schema, + &right_file_path, + "right", + file_sort_order, + )?; + let sql = "SELECT * FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a3 > t2.a3 + 3 AND t1.a3 < t2.a3 + 10"; + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); + let expected = { + [ + "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10", + " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 849d85dec6bf..3f52d2aae894 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -76,7 +76,6 @@ pub mod create_drop; pub mod csv_files; pub mod explain_analyze; pub mod expr; -pub mod group_by; pub mod joins; pub mod order; pub mod partitioned_csv; diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index ada66503a181..e74857cb313b 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -18,470 +18,6 @@ use super::*; use std::ops::Add; -#[tokio::test] -async fn test_current_timestamp_expressions() -> Result<()> { - let t1 = chrono::Utc::now().timestamp(); - let ctx = SessionContext::new(); - let actual = execute(&ctx, "SELECT NOW(), NOW() as t2").await; - let res1 = actual[0][0].as_str(); - let res2 = actual[0][1].as_str(); - let t3 = Utc::now().timestamp(); - let t2_naive = DateTime::parse_from_rfc3339(res1).unwrap(); - - let t2 = t2_naive.timestamp(); - assert!(t1 <= t2 && t2 <= t3); - assert_eq!(res2, res1); - - Ok(()) -} - 
-#[tokio::test] -async fn test_now_in_same_stmt_using_sql_function() -> Result<()> { - let ctx = SessionContext::new(); - - let df1 = ctx.sql("select now(), now() as now2").await?; - let result = result_vec(&df1.collect().await?); - assert_eq!(result[0][0], result[0][1]); - - Ok(()) -} - -#[tokio::test] -async fn test_now_across_statements() -> Result<()> { - let ctx = SessionContext::new(); - - let actual1 = execute(&ctx, "SELECT NOW()").await; - let res1 = actual1[0][0].as_str(); - - let actual2 = execute(&ctx, "SELECT NOW()").await; - let res2 = actual2[0][0].as_str(); - - assert!(res1 < res2); - - Ok(()) -} - -#[tokio::test] -async fn test_now_across_statements_using_sql_function() -> Result<()> { - let ctx = SessionContext::new(); - - let df1 = ctx.sql("select now()").await?; - let rb1 = df1.collect().await?; - let result1 = result_vec(&rb1); - let res1 = result1[0][0].as_str(); - - let df2 = ctx.sql("select now()").await?; - let rb2 = df2.collect().await?; - let result2 = result_vec(&rb2); - let res2 = result2[0][0].as_str(); - - assert!(res1 < res2); - - Ok(()) -} - -#[tokio::test] -async fn test_now_dataframe_api() -> Result<()> { - let ctx = SessionContext::new(); - let df = ctx.sql("select 1").await?; // use this to get a DataFrame - let df = df.select(vec![now(), now().alias("now2")])?; - let result = result_vec(&df.collect().await?); - assert_eq!(result[0][0], result[0][1]); - - Ok(()) -} - -#[tokio::test] -async fn test_now_dataframe_api_across_statements() -> Result<()> { - let ctx = SessionContext::new(); - let df = ctx.sql("select 1").await?; // use this to get a DataFrame - let df = df.select(vec![now()])?; - let result = result_vec(&df.collect().await?); - - let df = ctx.sql("select 1").await?; - let df = df.select(vec![now()])?; - let result2 = result_vec(&df.collect().await?); - - assert_ne!(result[0][0], result2[0][0]); - - Ok(()) -} - -#[tokio::test] -async fn test_now_in_view() -> Result<()> { - let ctx = SessionContext::new(); - let _df = ctx - .sql("create or replace view test_now as select now()") - .await? 
- .collect() - .await?; - - let df = ctx.sql("select * from test_now").await?; - let result = result_vec(&df.collect().await?); - - let df1 = ctx.sql("select * from test_now").await?; - let result2 = result_vec(&df1.collect().await?); - - assert_ne!(result[0][0], result2[0][0]); - - Ok(()) -} - -#[tokio::test] -async fn timestamp_minmax() -> Result<()> { - let ctx = SessionContext::new(); - let table_a = make_timestamp_tz_table::(None)?; - let table_b = - make_timestamp_tz_table::(Some("+00:00".into()))?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT MIN(table_a.ts), MAX(table_b.ts) FROM table_a, table_b"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------------------+-----------------------------+", - "| MIN(table_a.ts) | MAX(table_b.ts) |", - "+-------------------------+-----------------------------+", - "| 2020-09-08T11:42:29.190 | 2020-09-08T13:42:29.190855Z |", - "+-------------------------+-----------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn timestamp_coercion() -> Result<()> { - { - let ctx = SessionContext::new(); - let table_a = - make_timestamp_tz_table::(Some("+00:00".into()))?; - let table_b = - make_timestamp_tz_table::(Some("+00:00".into()))?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------+--------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------+--------------------------+-------------------------+", - "| 2020-09-08T13:42:29Z | 2020-09-08T13:42:29.190Z | true |", - "| 2020-09-08T13:42:29Z | 2020-09-08T12:42:29.190Z | false |", - "| 2020-09-08T13:42:29Z | 2020-09-08T11:42:29.190Z | false |", - "| 2020-09-08T12:42:29Z | 2020-09-08T13:42:29.190Z | false |", - "| 2020-09-08T12:42:29Z | 2020-09-08T12:42:29.190Z | true |", - "| 2020-09-08T12:42:29Z | 2020-09-08T11:42:29.190Z | false |", - "| 2020-09-08T11:42:29Z | 2020-09-08T13:42:29.190Z | false |", - "| 2020-09-08T11:42:29Z | 2020-09-08T12:42:29.190Z | false |", - "| 2020-09-08T11:42:29Z | 2020-09-08T11:42:29.190Z | true |", - "+----------------------+--------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+---------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+---------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29 | 
2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T11:42:29.190855 | true |", - "+---------------------+----------------------------+-------------------------+", - - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+---------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+---------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29 | 2020-09-08T11:42:29.190855 | true |", - "+---------------------+----------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------------------------+---------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+-------------------------+---------------------+-------------------------+", - "| 2020-09-08T13:42:29.190 | 2020-09-08T13:42:29 | true |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T12:42:29 | false |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T12:42:29 | true |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T12:42:29 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T11:42:29 | true |", - "+-------------------------+---------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - 
"+-------------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+-------------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T11:42:29.190855 | true |", - "+-------------------------+----------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+-------------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29.190 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29.190 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190 | 2020-09-08T11:42:29.190855 | true |", - "+-------------------------+----------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+---------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+---------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T12:42:29 | false |", - 
"| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29 | true |", - "+----------------------------+---------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+-------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+-------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29.190 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29.190 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29.190 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29.190 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29.190 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T12:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29.190 | true |", - "+----------------------------+-------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29.190855 | true |", - "+----------------------------+----------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - 
"+----------------------------+---------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+---------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T12:42:29 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29 | true |", - "+----------------------------+---------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+-------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+-------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29.190 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29.190 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29.190 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29.190 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29.190 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T12:42:29.190 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29.190 | true |", - "+----------------------------+-------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - { - let ctx = SessionContext::new(); - let table_a = make_timestamp_table::()?; - let table_b = make_timestamp_table::()?; - ctx.register_table("table_a", table_a)?; - ctx.register_table("table_b", table_b)?; - - let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+----------------------------+-------------------------+", - "| ts | ts | table_a.ts = table_b.ts |", - "+----------------------------+----------------------------+-------------------------+", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:29.190855 | true |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T13:42:29.190855 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:29.190855 | true |", - "| 2020-09-08T12:42:29.190855 | 2020-09-08T11:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T13:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 
2020-09-08T12:42:29.190855 | false |", - "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:29.190855 | true |", - "+----------------------------+----------------------------+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - #[tokio::test] async fn group_by_timestamp_millis() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index fb0ecd02c6b0..5882718acefd 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -36,8 +36,7 @@ use datafusion::{ assert_batches_eq, error::Result, logical_expr::{ - AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, TypeSignature, Volatility, + AccumulatorFactoryFunction, AggregateUDF, Signature, TypeSignature, Volatility, }, physical_plan::Accumulator, prelude::SessionContext, @@ -46,7 +45,7 @@ use datafusion::{ use datafusion_common::{ assert_contains, cast::as_primitive_array, exec_err, DataFusionError, }; -use datafusion_expr::create_udaf; +use datafusion_expr::{create_udaf, SimpleAggregateUDF}; use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup @@ -141,7 +140,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\""); + assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); } /// Basic query for with a udaf returning a structure @@ -408,26 +407,27 @@ impl TimeSum { fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let input_type = vec![timestamp_type.clone()]; // Returns the same type as its input - let return_type = Arc::new(timestamp_type.clone()); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); + let return_type = timestamp_type.clone(); - let state_type = Arc::new(vec![timestamp_type.clone()]); - let state_type: StateTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&state_type))); + let state_type = vec![timestamp_type.clone()]; let volatility = Volatility::Immutable; - let signature = Signature::exact(vec![timestamp_type], volatility); - let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); - let time_sum = - AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type); + let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( + name, + input_type, + return_type, + volatility, + accumulator, + state_type, + )); // register the selector as "time_sum" 
ctx.register_udaf(time_sum) @@ -510,11 +510,8 @@ impl FirstSelector { } fn register(ctx: &mut SessionContext) { - let return_type = Arc::new(Self::output_datatype()); - let state_type = Arc::new(Self::state_datatypes()); - - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + let return_type = Self::output_datatype(); + let state_type = Self::state_datatypes(); // Possible input signatures let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; @@ -526,13 +523,13 @@ impl FirstSelector { let name = "first"; - let first = AggregateUDF::new( + let first = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( name, - &Signature::one_of(signatures, volatility), - &return_type, - &accumulator, - &state_type, - ); + Signature::one_of(signatures, volatility), + return_type, + accumulator, + state_type, + )); // register the selector as "first" ctx.register_udaf(first) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 985b0bd5bc76..4f39f2374ea9 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -291,8 +291,8 @@ async fn udaf_as_window_func() -> Result<()> { context.register_udaf(my_acc); let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table"; - let expected = r#"Projection: my_table.a, AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - WindowAggr: windowExpr=[[AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; let dataframe = context.sql(sql).await.unwrap(); diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3040fbafe81a..54eab4315a97 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -471,6 +471,7 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { + #[derive(Debug, Clone)] struct SimpleWindowUDF { signature: Signature, return_type: DataType, diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 8556335b395a..5158e7177335 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -24,7 +24,69 @@ use std::{ use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; -/// Configuration options for Execution context +/// Configuration options 
for [`SessionContext`]. +/// +/// Can be passed to [`SessionContext::new_with_config`] to customize the configuration of DataFusion. +/// +/// Options can be set using namespaced keys with `.` as the separator, where the +/// namespace determines which configuration struct the value is routed to. All +/// built-in options are under the `datafusion` namespace. +/// +/// For example, the key `datafusion.execution.batch_size` will set [ExecutionOptions::batch_size][datafusion_common::config::ExecutionOptions::batch_size], +/// because [ConfigOptions::execution] is [ExecutionOptions][datafusion_common::config::ExecutionOptions]. Similarly, the key +/// `datafusion.execution.parquet.pushdown_filters` will set [ParquetOptions::pushdown_filters][datafusion_common::config::ParquetOptions::pushdown_filters], +/// since [ExecutionOptions::parquet][datafusion_common::config::ExecutionOptions::parquet] is [ParquetOptions][datafusion_common::config::ParquetOptions]. +/// +/// Some options have convenience methods. For example [SessionConfig::with_batch_size] is +/// shorthand for setting `datafusion.execution.batch_size`. +/// +/// ``` +/// use datafusion_execution::config::SessionConfig; +/// use datafusion_common::ScalarValue; +/// +/// let config = SessionConfig::new() +/// .set("datafusion.execution.batch_size", ScalarValue::UInt64(Some(1234))) +/// .set_bool("datafusion.execution.parquet.pushdown_filters", true); +/// +/// assert_eq!(config.batch_size(), 1234); +/// assert_eq!(config.options().execution.batch_size, 1234); +/// assert_eq!(config.options().execution.parquet.pushdown_filters, true); +/// ``` +/// +/// You can also directly mutate the options via [SessionConfig::options_mut]. +/// So the following is equivalent to the above: +/// +/// ``` +/// # use datafusion_execution::config::SessionConfig; +/// # use datafusion_common::ScalarValue; +/// # +/// let mut config = SessionConfig::new(); +/// config.options_mut().execution.batch_size = 1234; +/// config.options_mut().execution.parquet.pushdown_filters = true; +/// # +/// # assert_eq!(config.batch_size(), 1234); +/// # assert_eq!(config.options().execution.batch_size, 1234); +/// # assert_eq!(config.options().execution.parquet.pushdown_filters, true); +/// ``` +/// +/// ## Built-in options +/// +/// | Namespace | Config struct | +/// | --------- | ------------- | +/// | `datafusion.catalog` | [CatalogOptions][datafusion_common::config::CatalogOptions] | +/// | `datafusion.execution` | [ExecutionOptions][datafusion_common::config::ExecutionOptions] | +/// | `datafusion.execution.aggregate` | [AggregateOptions][datafusion_common::config::AggregateOptions] | +/// | `datafusion.execution.parquet` | [ParquetOptions][datafusion_common::config::ParquetOptions] | +/// | `datafusion.optimizer` | [OptimizerOptions][datafusion_common::config::OptimizerOptions] | +/// | `datafusion.sql_parser` | [SqlParserOptions][datafusion_common::config::SqlParserOptions] | +/// | `datafusion.explain` | [ExplainOptions][datafusion_common::config::ExplainOptions] | +/// +/// ## Custom configuration +/// +/// Configuration options can be extended. See [SessionConfig::with_extension] for details.
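As a complement to the string-keyed examples in the new `SessionConfig` docs above, here is a minimal sketch of the typed convenience-method path. It is illustrative only and not part of the patch: `with_batch_size` is named in the docs, `with_target_partitions` is used elsewhere in this diff, and the values are arbitrary.

```rust
use datafusion_execution::config::SessionConfig;

fn main() {
    // Typed convenience methods are shorthand for the namespaced string keys,
    // e.g. `with_batch_size` sets `datafusion.execution.batch_size`.
    let config = SessionConfig::new()
        .with_batch_size(4096)
        .with_target_partitions(8);

    assert_eq!(config.options().execution.batch_size, 4096);
    assert_eq!(config.options().execution.target_partitions, 8);
}
```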
+/// +/// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html +/// [`SessionContext::new_with_config`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.new_with_config #[derive(Clone, Debug)] pub struct SessionConfig { /// Configuration options @@ -62,6 +124,35 @@ impl SessionConfig { Ok(ConfigOptions::from_string_hash_map(settings)?.into()) } + /// Return a handle to the configuration options. + /// + /// Can be used to read the current configuration. + /// + /// ``` + /// use datafusion_execution::config::SessionConfig; + /// + /// let config = SessionConfig::new(); + /// assert!(config.options().execution.batch_size > 0); + /// ``` + pub fn options(&self) -> &ConfigOptions { + &self.options + } + + /// Return a mutable handle to the configuration options. + /// + /// Can be used to set configuration options. + /// + /// ``` + /// use datafusion_execution::config::SessionConfig; + /// + /// let mut config = SessionConfig::new(); + /// config.options_mut().execution.batch_size = 1024; + /// assert_eq!(config.options().execution.batch_size, 1024); + /// ``` + pub fn options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } + /// Set a configuration option pub fn set(mut self, key: &str, value: ScalarValue) -> Self { self.options.set(key, &value.to_string()).unwrap(); @@ -346,16 +437,6 @@ impl SessionConfig { &mut self.options } - /// Return a handle to the configuration options. - pub fn options(&self) -> &ConfigOptions { - &self.options - } - - /// Return a mutable handle to the configuration options. - pub fn options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - /// Add extensions. /// /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index 5a1cdb769098..7626f8bef162 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -21,6 +21,7 @@ use dashmap::DashMap; use datafusion_common::{exec_err, DataFusionError, Result}; +#[cfg(not(target_arch = "wasm32"))] use object_store::local::LocalFileSystem; use object_store::ObjectStore; use std::sync::Arc; @@ -169,16 +170,24 @@ impl Default for DefaultObjectStoreRegistry { impl DefaultObjectStoreRegistry { /// This will register [`LocalFileSystem`] to handle `file://` paths + #[cfg(not(target_arch = "wasm32"))] pub fn new() -> Self { let object_stores: DashMap> = DashMap::new(); object_stores.insert("file://".to_string(), Arc::new(LocalFileSystem::new())); Self { object_stores } } + + /// Default without any backend registered. + #[cfg(target_arch = "wasm32")] + pub fn new() -> Self { + Self::default() + } } /// /// Stores are registered based on the scheme, host and port of the provided URL -/// with a [`LocalFileSystem::new`] automatically registered for `file://` +/// with a [`LocalFileSystem::new`] automatically registered for `file://` (if the +/// target arch is not `wasm32`). /// /// For example: /// diff --git a/datafusion/expr/src/array_expressions.rs b/datafusion/expr/src/array_expressions.rs deleted file mode 100644 index 6469437866f4..000000000000 --- a/datafusion/expr/src/array_expressions.rs +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; - -/// Currently supported types by the array function. -/// The order of these types correspond to the order on which coercion applies -/// This should thus be from least informative to most informative -pub static SUPPORTED_ARRAY_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e642dae06e4f..6f64642f60d9 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -187,6 +187,8 @@ pub enum BuiltinScalarFunction { ArrayExcept, /// cardinality Cardinality, + /// array_resize + ArrayResize, /// construct an array from columns MakeArray, /// Flatten @@ -430,6 +432,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, + BuiltinScalarFunction::ArrayResize => Volatility::Immutable, BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, @@ -617,6 +620,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayResize => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::ArrayIntersect => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { @@ -980,6 +984,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayResize => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::Range => Signature::one_of( vec![ Exact(vec![Int64]), @@ -1647,6 +1655,7 @@ impl BuiltinScalarFunction { ], BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], BuiltinScalarFunction::Cardinality => &["cardinality"], + BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => { &["array_intersect", "list_intersect"] diff --git a/datafusion/expr/src/expr.rs 
b/datafusion/expr/src/expr.rs index ebf4d3143c12..40d40692e593 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Expr module contains core type definition for `Expr`. +//! Logical Expressions: [`Expr`] use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; @@ -1948,6 +1948,7 @@ mod test { ); // UDF + #[derive(Debug)] struct TestScalarUDF { signature: Signature, } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f76fb17b38bb..834420e413b0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -25,13 +25,14 @@ use crate::function::PartitionEvaluatorFactory; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, - BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, - ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, + BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, + ScalarUDF, Signature, Volatility, }; -use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; +use std::fmt::Debug; use std::ops::Not; use std::sync::Arc; @@ -747,6 +748,14 @@ scalar_expr!( array, "returns the total number of elements in the array." ); + +scalar_expr!( + ArrayResize, + array_resize, + array size value, + "returns an array with the specified size filled with the given value." +); + nary_scalar_expr!( MakeArray, array, @@ -983,6 +992,16 @@ pub struct SimpleScalarUDF { fun: ScalarFunctionImplementation, } +impl Debug for SimpleScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + impl SimpleScalarUDF { /// Create a new `SimpleScalarUDF` from a name, input types, return type and /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility @@ -1036,15 +1055,102 @@ pub fn create_udaf( accumulator: AccumulatorFactoryFunction, state_type: Arc>, ) -> AggregateUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); - AggregateUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + AggregateUDF::from(SimpleAggregateUDF::new( name, - &Signature::exact(input_type, volatility), - &return_type, - &accumulator, - &state_type, - ) + input_type, + return_type, + volatility, + accumulator, + state_type, + )) +} + +/// Implements [`AggregateUDFImpl`] for functions that have a single signature and +/// return type. 
+pub struct SimpleAggregateUDF { + name: String, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, +} + +impl Debug for SimpleAggregateUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl SimpleAggregateUDF { + /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: Vec, + return_type: DataType, + volatility: Volatility, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + return_type, + accumulator, + state_type, + } + } + + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, + ) -> Self { + let name = name.into(); + Self { + name, + signature, + return_type, + accumulator, + state_type, + } + } +} + +impl AggregateUDFImpl for SimpleAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(self.state_type.clone()) + } } /// Creates a new UDWF with a specific signature, state type and return type. @@ -1078,6 +1184,17 @@ pub struct SimpleWindowUDF { partition_evaluator_factory: PartitionEvaluatorFactory, } +impl Debug for SimpleWindowUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish() + } +} + impl SimpleWindowUDF { /// Create a new `SimpleWindowUDF` from a name, input types, return type and /// implementation. 
Implementing [`WindowUDFImpl`] allows more flexibility diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 077681d21725..21647f384159 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -40,7 +40,6 @@ mod udf; mod udwf; pub mod aggregate_function; -pub mod array_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; @@ -80,7 +79,7 @@ pub use signature::{ FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::AggregateUDF; +pub use udaf::{AggregateUDF, AggregateUDFImpl}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a684f3e97485..847fbbbf61c7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1845,13 +1845,16 @@ mod tests { .project(vec![col("id"), col("first_name").alias("id")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("id", &name); Ok(()) @@ -1872,13 +1875,16 @@ mod tests { .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("state", &name); Ok(()) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1b62c1bc05c1..6bacc1870079 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_concat_internal_coercion(from_type, &LargeUtf8) } - // TODO: cast between array elements (#6558) - (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index cfbca4ab1337..4983f6247d24 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -23,6 +23,7 @@ use crate::{ }; use arrow::datatypes::DataType; use datafusion_common::Result; +use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -42,36 +43,29 @@ use std::sync::Arc; /// /// For more information, please see [the examples]. /// +/// 1. For simple (less performant) use cases, use [`create_udaf`] and [`simple_udaf.rs`]. +/// +/// 2. For advanced use cases, use [`AggregateUDFImpl`] and [`advanced_udaf.rs`]. +/// +/// # API Note +/// This is a separate struct from `AggregateUDFImpl` to maintain backwards +/// compatibility with the older API. 
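For the "simple" path mentioned in the note above, a hedged sketch of registering a UDAF through `create_udaf`, following the argument order visible in the `expr_fn.rs` hunk of this diff. The `geo_mean` name is a placeholder and the accumulator factory is deliberately left unimplemented; a real UDAF would return a working `Accumulator` from it.

```rust
use std::sync::Arc;
use arrow::datatypes::DataType;
use datafusion_expr::{create_udaf, AccumulatorFactoryFunction, AggregateUDF, Volatility};

// Sketch only: builds an AggregateUDF via the backwards-compatible helper.
fn geo_mean_udaf() -> AggregateUDF {
    // Placeholder factory: a real UDAF would construct its Accumulator here.
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| unimplemented!("construct an Accumulator here"));

    create_udaf(
        "geo_mean",                   // function name (placeholder)
        vec![DataType::Float64],      // input types
        Arc::new(DataType::Float64),  // return type
        Volatility::Immutable,
        accumulator,
        Arc::new(vec![DataType::Float64, DataType::UInt32]), // intermediate state types
    )
}
```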
+/// /// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function /// [`Accumulator`]: crate::Accumulator -#[derive(Clone)] -pub struct AggregateUDF { - /// name - name: String, - /// Signature (input arguments) - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// actual implementation - accumulator: AccumulatorFactoryFunction, - /// the accumulator's state's description as a function of the return type - state_type: StateTypeFunction, -} +/// [`create_udaf`]: crate::expr_fn::create_udaf +/// [`simple_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } +#[derive(Debug, Clone)] +pub struct AggregateUDF { + inner: Arc, } impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -79,13 +73,17 @@ impl Eq for AggregateUDF {} impl std::hash::Hash for AggregateUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } impl AggregateUDF { /// Create a new AggregateUDF + /// + /// See [`AggregateUDFImpl`] for a more convenient way to create a + /// `AggregateUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -93,15 +91,32 @@ impl AggregateUDF { accumulator: &AccumulatorFactoryFunction, state_type: &StateTypeFunction, ) -> Self { - Self { + Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), accumulator: accumulator.clone(), state_type: state_type.clone(), + }) + } + + /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`AggregateUDF::from`) + pub fn new_from_impl(fun: F) -> AggregateUDF + where + F: AggregateUDFImpl + 'static, + { + Self { + inner: Arc::new(fun), } } + /// Return the underlying [`AggregateUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + /// creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to @@ -117,33 +132,176 @@ impl AggregateUDF { } /// Returns this function's name + /// + /// See [`AggregateUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } /// Returns this function's signature (what input types are accepted) + /// + /// See [`AggregateUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } /// Return the type of the function given its input types + /// + /// See [`AggregateUDFImpl::return_type`] for more details. 
pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) } /// Return an accumualator the given aggregate, given /// its return datatype. pub fn accumulator(&self, return_type: &DataType) -> Result> { - (self.accumulator)(return_type) + self.inner.accumulator(return_type) } /// Return the type of the intermediate state used by this aggregator, given /// its return datatype. Supports multi-phase aggregations pub fn state_type(&self, return_type: &DataType) -> Result> { - // old API returns an Arc for some reason, try and unwrap it here + self.inner.state_type(return_type) + } +} + +impl From for AggregateUDF +where + F: AggregateUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`AggregateUDF`]. +/// +/// This trait exposes the full API for implementing user defined aggregate functions and +/// can be used to implement any function. +/// +/// See [`advanced_udaf.rs`] for a full example with complete implementation and +/// [`AggregateUDF`] for other available options. +/// +/// +/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// #[derive(Debug, Clone)] +/// struct GeoMeanUdf { +/// signature: Signature +/// }; +/// +/// impl GeoMeanUdf { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf +/// impl AggregateUDFImpl for GeoMeanUdf { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "geo_mean" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Float64)) { +/// return plan_err!("add_one only accepts Float64 arguments"); +/// } +/// Ok(DataType::Float64) +/// } +/// // This is the accumulator factory; DataFusion uses it to create new accumulators. +/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } +/// fn state_type(&self, _return_type: &DataType) -> Result> { +/// Ok(vec![DataType::Float64, DataType::UInt32]) +/// } +/// } +/// +/// // Create a new AggregateUDF from the implementation +/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); +/// +/// // Call the function `geo_mean(col)` +/// let expr = geometric_mean.call(vec![col("a")]); +/// ``` +pub trait AggregateUDFImpl: Debug + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. 
+ fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Return a new [`Accumulator`] that aggregates values for a specific + /// group during query execution. + fn accumulator(&self, arg: &DataType) -> Result>; + + /// Return the type used to serialize the [`Accumulator`]'s intermediate state. + /// See [`Accumulator::state()`] for more details + fn state_type(&self, return_type: &DataType) -> Result>; +} + +/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers +/// of the older API +pub struct AggregateUDFLegacyWrapper { + /// name + name: String, + /// Signature (input arguments) + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// actual implementation + accumulator: AccumulatorFactoryFunction, + /// the accumulator's state's description as a function of the return type + state_type: StateTypeFunction, +} + +impl Debug for AggregateUDFLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl AggregateUDFImpl for AggregateUDFLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg) + } + + fn state_type(&self, return_type: &DataType) -> Result> { let res = (self.state_type)(return_type)?; - Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) + Ok(res.as_ref().clone()) } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 2ec80a4a9ea1..3017e1ec0271 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,7 +18,8 @@ //! [`ScalarUDF`]: Scalar User Defined Functions use crate::{ - ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, + ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, + ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -35,48 +36,26 @@ use std::sync::Arc; /// functions you supply such name, type signature, return type, and actual /// implementation. /// -/// /// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. /// /// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. /// +/// # API Note +/// +/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// compatibility with the older API. 
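Tying together the udaf.rs hunks above: because `inner()` hands out the `Arc<dyn AggregateUDFImpl>` and the trait requires `as_any()`, callers can recover the concrete implementation behind an `AggregateUDF` when they need to. A hedged sketch with a hypothetical `MyUdaf` type; only constructors and accessors introduced above are assumed:

use std::any::Any;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};

#[derive(Debug)]
struct MyUdaf {
    signature: Signature,
}

impl AggregateUDFImpl for MyUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_udaf"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }
    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
        unimplemented!("not needed for this sketch")
    }
    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
        Ok(vec![DataType::Float64])
    }
}

fn build() -> AggregateUDF {
    AggregateUDF::from(MyUdaf {
        signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
    })
}

fn is_my_udaf(udaf: &AggregateUDF) -> bool {
    // inner() exposes the trait object; as_any() allows downcasting to the concrete type.
    udaf.inner().as_any().downcast_ref::<MyUdaf>().is_some()
}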
+/// /// [`create_udf`]: crate::expr_fn::create_udf /// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs /// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarUDF { - /// The name of the function - name: String, - /// The signature (the types of arguments that are supported) - signature: Signature, - /// Function that returns the return type given the argument types - return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - fun: ScalarFunctionImplementation, - /// Optional aliases for the function. This list should NOT include the value of `name` as well - aliases: Vec, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } + inner: Arc, } impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -84,8 +63,8 @@ impl Eq for ScalarUDF {} impl std::hash::Hash for ScalarUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -101,13 +80,12 @@ impl ScalarUDF { return_type: &ReturnTypeFunction, fun: &ScalarFunctionImplementation, ) -> Self { - Self { + Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), - aliases: vec![], - } + }) } /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object @@ -115,37 +93,24 @@ impl ScalarUDF { /// Note this is the same as using the `From` impl (`ScalarUDF::from`) pub fn new_from_impl(fun: F) -> ScalarUDF where - F: ScalarUDFImpl + Send + Sync + 'static, + F: ScalarUDFImpl + 'static, { - // TODO change the internal implementation to use the trait object - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let func: ScalarFunctionImplementation = - Arc::new(move |args| captured_self.invoke(args)); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - fun: func, - aliases: arc_fun.aliases().to_vec(), + inner: Arc::new(fun), } } - /// Adds additional names that can be used to invoke this function, in addition to `name` - pub fn with_aliases( - mut self, - aliases: impl IntoIterator, - ) -> Self { - self.aliases - .extend(aliases.into_iter().map(|s| s.to_string())); - self + /// Return the underlying [`ScalarUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + + /// Adds additional names 
that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -159,31 +124,53 @@ impl ScalarUDF { )) } - /// Returns this function's name + /// Returns this function's name. + /// + /// See [`ScalarUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } - /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + /// Returns the aliases for this function. + /// + /// See [`ScalarUDF::with_aliases`] for more details pub fn aliases(&self) -> &[String] { - &self.aliases + self.inner.aliases() } - /// Returns this function's [`Signature`] (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted). + /// + /// See [`ScalarUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } - /// The datatype this function returns given the input argument input types + /// The datatype this function returns given the input argument input types. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) } - /// Return an [`Arc`] to the function implementation + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke`] for more details. + pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + /// Returns a `ScalarFunctionImplementation` that can invoke the function + /// during execution pub fn fun(&self) -> ScalarFunctionImplementation { - self.fun.clone() + let captured = self.inner.clone(); + Arc::new(move |args| captured.invoke(args)) + } + + /// This function specifies monotonicity behaviors for User defined scalar functions. + /// + /// See [`ScalarUDFImpl::monotonicity`] for more details. + pub fn monotonicity(&self) -> Result> { + self.inner.monotonicity() } } @@ -213,6 +200,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// #[derive(Debug)] /// struct AddOne { /// signature: Signature /// }; @@ -246,7 +234,7 @@ where /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl { +pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -291,4 +279,112 @@ pub trait ScalarUDFImpl { fn aliases(&self) -> &[String] { &[] } + + /// This function specifies monotonicity behaviors for User defined scalar functions. + fn monotonicity(&self) -> Result> { + Ok(None) + } +} + +/// ScalarUDF that adds an alias to the underlying function. It is better to +/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. 
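In practice that means a function with legacy spellings can return them from `aliases()` itself instead of being wrapped. A hedged sketch using only the `ScalarUDFImpl` methods shown above; the `my_func` name and types are illustrative:

use std::any::Any;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility};

#[derive(Debug)]
struct MyFunc {
    signature: Signature,
    aliases: Vec<String>,
}

impl MyFunc {
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
            // Extra names under which this function can be invoked.
            aliases: vec!["my_func_legacy".to_string()],
        }
    }
}

impl ScalarUDFImpl for MyFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_func"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        unimplemented!("not needed for this sketch")
    }
    // Returned directly from the implementation; no AliasedScalarUDFImpl wrapper needed.
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}

fn register() -> ScalarUDF {
    // The resulting ScalarUDF answers to both "my_func" and "my_func_legacy".
    ScalarUDF::from(MyFunc::new())
}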
+#[derive(Debug)] +struct AliasedScalarUDFImpl { + inner: ScalarUDF, + aliases: Vec, +} + +impl AliasedScalarUDFImpl { + pub fn new( + inner: ScalarUDF, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl ScalarUDFImpl for AliasedScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +struct ScalarUdfLegacyWrapper { + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). + fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUdfLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdfLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } + + fn aliases(&self) -> &[String] { + &[] + } } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 800386bfc77b..9b8f94f4b020 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -34,40 +34,33 @@ use std::{ /// /// See the documetnation on [`PartitionEvaluator`] for more details /// +/// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`]. +/// +/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`]. +/// +/// # API Note +/// This is a separate struct from `WindowUDFImpl` to maintain backwards +/// compatibility with the older API. 
+/// /// [`PartitionEvaluator`]: crate::PartitionEvaluator -#[derive(Clone)] +/// [`create_udwf`]: crate::expr_fn::create_udwf +/// [`simple_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +#[derive(Debug, Clone)] pub struct WindowUDF { - /// name - name: String, - /// signature - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// Return the partition evaluator - partition_evaluator_factory: PartitionEvaluatorFactory, -} - -impl Debug for WindowUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish_non_exhaustive() - } + inner: Arc, } /// Defines how the WindowUDF is shown to users impl Display for WindowUDF { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}", self.name()) } } impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -75,8 +68,8 @@ impl Eq for WindowUDF {} impl std::hash::Hash for WindowUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -92,12 +85,12 @@ impl WindowUDF { return_type: &ReturnTypeFunction, partition_evaluator_factory: &PartitionEvaluatorFactory, ) -> Self { - Self { - name: name.to_string(), + Self::new_from_impl(WindowUDFLegacyWrapper { + name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), partition_evaluator_factory: partition_evaluator_factory.clone(), - } + }) } /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object @@ -105,27 +98,18 @@ impl WindowUDF { /// Note this is the same as using the `From` impl (`WindowUDF::from`) pub fn new_from_impl(fun: F) -> WindowUDF where - F: WindowUDFImpl + Send + Sync + 'static, + F: WindowUDFImpl + 'static, { - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || captured_self.partition_evaluator()); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - partition_evaluator_factory, + inner: Arc::new(fun), } } + /// Return the underlying [`WindowUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -150,25 +134,29 @@ impl WindowUDF { } /// Returns this function's name + /// + /// See [`WindowUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } /// Returns this function's signature (what input types are accepted) + /// + /// See [`WindowUDFImpl::signature`] for more details. 
pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } /// Return the type of the function given its input types + /// + /// See [`WindowUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) } /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory(&self) -> Result> { - (self.partition_evaluator_factory)() + self.inner.partition_evaluator() } } @@ -198,6 +186,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// #[derive(Debug, Clone)] /// struct SmoothIt { /// signature: Signature /// }; @@ -236,7 +225,7 @@ where /// WindowFrame::new(false), /// ); /// ``` -pub trait WindowUDFImpl { +pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -254,3 +243,52 @@ pub trait WindowUDFImpl { /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator(&self) -> Result>; } + +/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +pub struct WindowUDFLegacyWrapper { + /// name + name: String, + /// signature + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// Return the partition evaluator + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for WindowUDFLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish_non_exhaustive() + } +} + +impl WindowUDFImpl for WindowUDFLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } +} diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..9d47299a5616 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,6 +17,7 @@ pub mod count_wildcard_rule; pub mod inline_table_scan; +pub mod rewrite_expr; pub mod subquery; pub mod type_coercion; @@ -37,6 +38,8 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +use self::rewrite_expr::OperatorToFunction; + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// @@ -72,6 +75,9 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(InlineTableScan::new()), + // OperatorToFunction should be run before TypeCoercion, since it rewrite based on the argument types (List or Scalar), + // and TypeCoercion may cast the argument types from Scalar to List. 
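Concretely, the new rule only touches `Operator::StringConcat` (`||`) when an array is involved; the mapping implemented by `rewrite_array_concat_operator_to_func*` in the new file below reduces to the following, shown as a distilled, self-contained sketch rather than the rule's own code:

/// Which built-in function `||` is rewritten to, as a function of whether each
/// operand is (or evaluates to) an array. Mirrors the match arms in rewrite_expr.rs.
fn concat_rewrite_target(left_is_array: bool, right_is_array: bool) -> Option<&'static str> {
    match (left_is_array, right_is_array) {
        (true, true) => Some("array_concat"),   // [1,2] || [3,4]
        (true, false) => Some("array_append"),  // [1,2] || 3
        (false, true) => Some("array_prepend"), // 1 || [2,3]
        (false, false) => None,                 // plain string concatenation, untouched
    }
}

#[test]
fn concat_rewrite_examples() {
    assert_eq!(concat_rewrite_target(true, false), Some("array_append"));
    assert_eq!(concat_rewrite_target(false, false), None);
}

For column operands the real rule additionally consults list_ndims() on the merged schema to decide between append, prepend, and concat.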
+ Arc::new(OperatorToFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs new file mode 100644 index 000000000000..8f1c844ed062 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule for to replace operators with function calls (e.g `||` to array_concat`) + +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::utils::list_ndims; +use datafusion_common::DFSchema; +use datafusion_common::DFSchemaRef; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::Operator; +use datafusion_expr::ScalarFunctionDefinition; +use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; + +use super::AnalyzerRule; + +#[derive(Default)] +pub struct OperatorToFunction {} + +impl OperatorToFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for OperatorToFunction { + fn name(&self) -> &str { + "operator_to_function" + } + + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + analyze_internal(&plan) + } +} + +fn analyze_internal(plan: &LogicalPlan) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p)) + .collect::>>()?; + + // get schema representing all available input fields. 
This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); + + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); + } + + let mut expr_rewrite = OperatorToFunctionRewriter { + schema: Arc::new(schema), + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::>>()?; + + plan.with_new_exprs(new_expr, &new_inputs) +} + +pub(crate) struct OperatorToFunctionRewriter { + pub(crate) schema: DFSchemaRef, +} + +impl TreeNodeRewriter for OperatorToFunctionRewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) => { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), + op, + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) + }) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + })); + } + + Ok(expr) + } + _ => Ok(expr), + } + } +} + +/// Summary of the logic below: +/// +/// 1) array || array -> array concat +/// +/// 2) array || scalar -> array append +/// +/// 3) scalar || array -> array prepend +/// +/// 4) (arry concat, array append, array prepend) || array -> array concat +/// +/// 5) (arry concat, array append, array prepend) || scalar -> array append +fn rewrite_array_concat_operator_to_func( + left: &Expr, + op: Operator, + right: &Expr, +) -> Option { + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + + if op != Operator::StringConcat { + return None; + } + + match (left, right) { + // Chain concat operator (a || b) || array, + // (arry concat, array append, array prepend) || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // Chain concat operator (a || b) || scalar, + // (arry concat, array append, array prepend) || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + 
func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + _scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // array || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // array || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + _right_scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // scalar || array -> array prepend + ( + _left_scalar, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayPrepend), + + _ => None, + } +} + +/// Summary of the logic below: +/// +/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat) +/// +/// 2) column1 || column2 -> (array prepend, array append, array concat) +fn rewrite_array_concat_operator_to_func_for_column( + left: &Expr, + op: Operator, + right: &Expr, + schema: &DFSchema, +) -> Result> { + if op != Operator::StringConcat { + return Ok(None); + } + + match (left, right) { + // Column cases: + // 1) array_prepend/append/concat || column + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::Column(c), + ) => { + let d = schema.field_from_column(c)?.data_type(); + let ndim = list_ndims(d); + match ndim { + 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + // 2) select column1 || column2 + (Expr::Column(c1), Expr::Column(c2)) => { + let d1 = schema.field_from_column(c1)?.data_type(); + let d2 = schema.field_from_column(c2)?.data_type(); + let ndim1 = list_ndims(d1); + let ndim2 = list_ndims(d2); + match (ndim1, ndim2) { + (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), + (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + _ => Ok(None), + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad99670..3821279fed0f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -751,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, 
StateTypeFunction, - Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, + SimpleAggregateUDF, Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, + Expr, LogicalPlan, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -811,6 +811,7 @@ mod test { static TEST_SIGNATURE: OnceLock = OnceLock::new(); + #[derive(Debug, Clone, Default)] struct TestScalarUDF {} impl ScalarUDFImpl for TestScalarUDF { fn as_any(&self) -> &dyn Any { @@ -902,19 +903,17 @@ mod test { #[test] fn agg_udaf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Float64))); - let state_type: StateTypeFunction = - Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); + let return_type = DataType::Float64; + let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::::default())); - let my_avg = AggregateUDF::new( + let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "MY_AVG", - &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), - &return_type, - &accumulator, - &state_type, - ); + Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), + return_type, + accumulator, + state_type, + )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61a..000329d0d078 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -780,8 +780,8 @@ mod test { avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, }; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, - Signature, StateTypeFunction, Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature, + SimpleAggregateUDF, Volatility, }; use crate::optimizer::OptimizerContext; @@ -901,21 +901,18 @@ mod test { fn aggregate() -> Result<()> { let table_scan = test_table_scan()?; - let return_type: ReturnTypeFunction = Arc::new(|inputs| { - assert_eq!(inputs, &[DataType::UInt32]); - Ok(Arc::new(DataType::UInt32)) - }); + let return_type = DataType::UInt32; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); - let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); + let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - Arc::new(AggregateUDF::new( + Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "my_agg", - &Signature::exact(vec![DataType::UInt32], Volatility::Stable), - &return_type, - &accumulator, - &state_type, - )), + Signature::exact(vec![DataType::UInt32], Volatility::Stable), + return_type.clone(), + accumulator.clone(), + state_type.clone(), + ))), vec![inner], false, None, diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 891a909a3378..d9c45510972c 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -36,9 +36,10 @@ use datafusion_common::{ use datafusion_expr::expr::{Alias, 
ScalarFunction, ScalarFunctionDefinition}; use datafusion_expr::{ logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, GroupingSet, Projection, TableScan, Window, + Expr, Projection, TableScan, Window, }; +use datafusion_expr::utils::inspect_expr_pre; use hashbrown::HashMap; use itertools::{izip, Itertools}; @@ -531,7 +532,7 @@ macro_rules! rewrite_expr_with_check { /// /// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. /// - `Ok(None)`: Signals that `expr` can not be rewritten. -/// - `Err(error)`: An error occured during the function call. +/// - `Err(error)`: An error occurred during the function call. fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { let result = match expr { Expr::Column(col) => { @@ -574,23 +575,7 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { Ok(Some(result)) } -/// Retrieves a set of outer-referenced columns by the given expression, `expr`. -/// Note that the `Expr::to_columns()` function doesn't return these columns. -/// -/// # Parameters -/// -/// * `expr` - The expression to analyze for outer-referenced columns. -/// -/// # Returns -/// -/// If the function can safely infer all outer-referenced columns, returns a -/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. -fn outer_columns(expr: &Expr) -> Option> { - let mut columns = HashSet::new(); - outer_columns_helper(expr, &mut columns).then_some(columns) -} - -/// A recursive subroutine that accumulates outer-referenced columns by the +/// Accumulates outer-referenced columns by the /// given expression, `expr`. /// /// # Parameters @@ -598,88 +583,31 @@ fn outer_columns(expr: &Expr) -> Option> { /// * `expr` - The expression to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -/// -/// Returns `true` if it can safely collect all outer-referenced columns. -/// Otherwise, returns `false`. 
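The helper removed below returned `false` for any expression variant it did not recognize, and the caller then conservatively required every input column; the replacement walks the whole tree with `inspect_expr_pre` and only special-cases subquery outer references. A minimal sketch of that shift, with an illustrative miniature expression type rather than DataFusion's `Expr`:

use std::collections::HashSet;

// Illustrative miniature expression tree.
enum MiniExpr {
    Column(String),
    OuterColumn(String),
    Binary(Box<MiniExpr>, Box<MiniExpr>),
}

// Pre-order walk that unconditionally visits every node and collects what it
// finds, instead of reporting "I could not handle this variant".
fn collect_outer_columns(expr: &MiniExpr, out: &mut HashSet<String>) {
    match expr {
        MiniExpr::OuterColumn(name) => {
            out.insert(name.clone());
        }
        MiniExpr::Binary(left, right) => {
            collect_outer_columns(left, out);
            collect_outer_columns(right, out);
        }
        MiniExpr::Column(_) => {}
    }
}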
-fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { - match expr { - Expr::OuterReferenceColumn(_, col) => { - columns.insert(col.clone()); - true - } - Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns) - && outer_columns_helper(&binary_expr.right, columns) - } - Expr::ScalarSubquery(subquery) => { - let exprs = subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) - } - Expr::Exists(exists) => { - let exprs = exists.subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) - } - Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), - Expr::InSubquery(insubquery) => { - let exprs = insubquery.subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) - } - Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), - Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), - Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), - Expr::AggregateFunction(aggregate_fn) => { - outer_columns_helper_multi(aggregate_fn.args.iter(), columns) - && aggregate_fn - .order_by - .as_ref() - .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) - && aggregate_fn - .filter - .as_ref() - .map_or(true, |filter| outer_columns_helper(filter, columns)) - } - Expr::WindowFunction(window_fn) => { - outer_columns_helper_multi(window_fn.args.iter(), columns) - && outer_columns_helper_multi(window_fn.order_by.iter(), columns) - && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) - } - Expr::GroupingSet(groupingset) => match groupingset { - GroupingSet::GroupingSets(multi_exprs) => multi_exprs - .iter() - .all(|e| outer_columns_helper_multi(e.iter(), columns)), - GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { - outer_columns_helper_multi(exprs.iter(), columns) +fn outer_columns(expr: &Expr, columns: &mut HashSet) { + // inspect_expr_pre doesn't handle subquery references, so find them explicitly + inspect_expr_pre(expr, |expr| { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col.clone()); } - }, - Expr::ScalarFunction(scalar_fn) => { - outer_columns_helper_multi(scalar_fn.args.iter(), columns) - } - Expr::Like(like) => { - outer_columns_helper(&like.expr, columns) - && outer_columns_helper(&like.pattern, columns) - } - Expr::InList(in_list) => { - outer_columns_helper(&in_list.expr, columns) - && outer_columns_helper_multi(in_list.list.iter(), columns) - } - Expr::Case(case) => { - let when_then_exprs = case - .when_then_expr - .iter() - .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); - outer_columns_helper_multi(when_then_exprs, columns) - && case - .expr - .as_ref() - .map_or(true, |expr| outer_columns_helper(expr, columns)) - && case - .else_expr - .as_ref() - .map_or(true, |expr| outer_columns_helper(expr, columns)) - } - Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. 
} => true, - _ => false, - } + Expr::ScalarSubquery(subquery) => { + outer_columns_helper_multi(&subquery.outer_ref_columns, columns); + } + Expr::Exists(exists) => { + outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); + } + Expr::InSubquery(insubquery) => { + outer_columns_helper_multi( + &insubquery.subquery.outer_ref_columns, + columns, + ); + } + _ => {} + }; + Ok(()) as Result<()> + }) + // unwrap: closure above never returns Err, so can not be Err here + .unwrap(); } /// A recursive subroutine that accumulates outer-referenced columns by the @@ -690,14 +618,11 @@ fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { /// * `exprs` - The expressions to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -/// -/// Returns `true` if it can safely collect all outer-referenced columns. -/// Otherwise, returns `false`. fn outer_columns_helper_multi<'a>( - mut exprs: impl Iterator, + exprs: impl IntoIterator, columns: &mut HashSet, -) -> bool { - exprs.all(|e| outer_columns_helper(e, columns)) +) { + exprs.into_iter().for_each(|e| outer_columns(e, columns)); } /// Generates the required expressions (columns) that reside at `indices` of @@ -765,14 +690,8 @@ fn indices_referred_by_expr( expr: &Expr, ) -> Result> { let mut cols = expr.to_columns()?; - // Get outer-referenced columns: - if let Some(outer_cols) = outer_columns(expr) { - cols.extend(outer_cols); - } else { - // Expression is not known to contain outer columns or not. Hence, do - // not assume anything and require all the schema indices at the input: - return Ok((0..input_schema.fields().len()).collect()); - } + // Get outer-referenced (subquery) columns: + outer_columns(expr, &mut cols); Ok(cols .iter() .flat_map(|col| input_schema.index_of_column(col)) @@ -978,8 +897,8 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, TableReference}; use datafusion_expr::{ - binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, - table_scan, Expr, LogicalPlan, Operator, + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, + table_scan, try_cast, Expr, Like, LogicalPlan, Operator, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -1060,4 +979,187 @@ mod tests { \n TableScan: ?table? projection=[]"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn test_struct_field_push_down() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new_struct( + "s", + vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Int64, false), + ], + false, + ), + ])); + + let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("s").field("x")])? + .build()?; + let expected = "Projection: (?table?.s)[x]\ + \n TableScan: ?table? projection=[s]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_neg_push_down() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![-col("a")])? 
+ .build()?; + + let expected = "Projection: (- test.a)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_null()])? + .build()?; + + let expected = "Projection: test.a IS NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_null()])? + .build()?; + + let expected = "Projection: test.a IS NOT NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_true()])? + .build()?; + + let expected = "Projection: test.a IS TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_true()])? + .build()?; + + let expected = "Projection: test.a IS NOT TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_false()])? + .build()?; + + let expected = "Projection: test.a IS FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_false()])? + .build()?; + + let expected = "Projection: test.a IS NOT FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_unknown()])? + .build()?; + + let expected = "Projection: test.a IS UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_unknown()])? + .build()?; + + let expected = "Projection: test.a IS NOT UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_not() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![not(col("a"))])? + .build()?; + + let expected = "Projection: NOT test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_try_cast() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![try_cast(col("a"), DataType::Float64)])? 
+ .build()?; + + let expected = "Projection: TRY_CAST(test.a AS Float64)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_similar_to() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Box::new(col("a")); + let pattern = Box::new(lit("[0-9]")); + let similar_to_expr = + Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![similar_to_expr])? + .build()?; + + let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_between() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").between(lit(1), lit(3))])? + .build()?; + + let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 7d09aec7e748..3ba343003e33 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,7 +27,7 @@ use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; use arrow::{ - array::new_null_array, + array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; @@ -381,12 +381,8 @@ impl<'a> ConstEvaluator<'a> { return Ok(s); } - let phys_expr = create_physical_expr( - &expr, - &self.input_schema, - &self.input_batch.schema(), - self.execution_props, - )?; + let phys_expr = + create_physical_expr(&expr, &self.input_schema, self.execution_props)?; let col_val = phys_expr.evaluate(&self.input_batch)?; match col_val { ColumnarValue::Array(a) => { @@ -396,7 +392,7 @@ impl<'a> ConstEvaluator<'a> { a.len() ) } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { - Ok(ScalarValue::List(a)) + Ok(ScalarValue::List(a.as_list().to_owned().into())) } else { // Non-ListArray ScalarValue::try_from_array(&a, 0) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 1efae424cc69..2d263a42e0ff 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -186,7 +186,6 @@ mod tests { use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; - use arrow_array::cast::as_list_array; use arrow_array::types::Int32Type; use arrow_array::{Array, ListArray}; use arrow_buffer::OffsetBuffer; @@ -196,10 +195,7 @@ mod tests { // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. 
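The count_distinct.rs hunks below specialize COUNT(DISTINCT) for primitive columns: instead of a `HashSet<ScalarValue>`, the new accumulators keep a `HashSet` of the native value type (with a byte-hashed wrapper for floats). A simplified, self-contained sketch of the idea; the real accumulators also implement state() and merge_batch() so the set survives multi-phase aggregation:

use std::collections::HashSet;
use std::hash::Hash;

struct NativeDistinctCounter<T: Eq + Hash> {
    values: HashSet<T>,
}

impl<T: Eq + Hash> NativeDistinctCounter<T> {
    fn new() -> Self {
        Self { values: HashSet::new() }
    }

    // Nulls (None) are ignored, matching COUNT(DISTINCT col) semantics.
    fn update(&mut self, column: impl IntoIterator<Item = Option<T>>) {
        for v in column.into_iter().flatten() {
            self.values.insert(v);
        }
    }

    fn evaluate(&self) -> i64 {
        self.values.len() as i64
    }
}

#[test]
fn distinct_count_ignores_nulls_and_duplicates() {
    let mut acc = NativeDistinctCounter::new();
    acc.update(vec![Some(1i64), Some(1), None, Some(3)]);
    assert_eq!(acc.evaluate(), 2);
}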
fn sort_list_inner(arr: ScalarValue) -> ScalarValue { let arr = match arr { - ScalarValue::List(arr) => { - let list_arr = as_list_array(&arr); - list_arr.value(0) - } + ScalarValue::List(arr) => arr.value(0), _ => { panic!("Expected ScalarValue::List, got {:?}", arr) } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index c2fd32a96c4f..021c33fb94a7 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,21 +15,32 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::types::{ + ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::PrimitiveArray; use std::any::Any; +use std::cmp::Eq; use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use std::collections::HashSet; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -60,6 +71,18 @@ impl DistinctCount { } } +macro_rules! native_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + +macro_rules! 
float_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + impl AggregateExpr for DistinctCount { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - })) + use DataType::*; + use TimeUnit::*; + + match &self.state_data_type { + Int8 => native_distinct_count_accumulator!(Int8Type), + Int16 => native_distinct_count_accumulator!(Int16Type), + Int32 => native_distinct_count_accumulator!(Int32Type), + Int64 => native_distinct_count_accumulator!(Int64Type), + UInt8 => native_distinct_count_accumulator!(UInt8Type), + UInt16 => native_distinct_count_accumulator!(UInt16Type), + UInt32 => native_distinct_count_accumulator!(UInt32Type), + UInt64 => native_distinct_count_accumulator!(UInt64Type), + Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type), + Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type), + + Date32 => native_distinct_count_accumulator!(Date32Type), + Date64 => native_distinct_count_accumulator!(Date64Type), + Time32(Millisecond) => { + native_distinct_count_accumulator!(Time32MillisecondType) + } + Time32(Second) => { + native_distinct_count_accumulator!(Time32SecondType) + } + Time64(Microsecond) => { + native_distinct_count_accumulator!(Time64MicrosecondType) + } + Time64(Nanosecond) => { + native_distinct_count_accumulator!(Time64NanosecondType) + } + Timestamp(Microsecond, _) => { + native_distinct_count_accumulator!(TimestampMicrosecondType) + } + Timestamp(Millisecond, _) => { + native_distinct_count_accumulator!(TimestampMillisecondType) + } + Timestamp(Nanosecond, _) => { + native_distinct_count_accumulator!(TimestampNanosecondType) + } + Timestamp(Second, _) => { + native_distinct_count_accumulator!(TimestampSecondType) + } + + Float16 => float_distinct_count_accumulator!(Float16Type), + Float32 => float_distinct_count_accumulator!(Float32Type), + Float64 => float_distinct_count_accumulator!(Float64Type), + + _ => Ok(Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: self.state_data_type.clone(), + })), + } } fn name(&self) -> &str { @@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator { } } +#[derive(Debug)] +struct NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + values: HashSet, +} + +impl NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().cloned(), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(value); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + 
states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values.extend(list.values()) + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + +#[derive(Debug)] +struct FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: HashSet, RandomState>, +} + +impl FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().map(|v| v.0), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(Hashable(value)); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values + .extend(list.values().iter().map(|v| Hashable(*v))); + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + #[cfg(test)] mod tests { use crate::expressions::NoOp; @@ -206,6 +452,8 @@ mod tests { Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; + use arrow_array::Decimal256Array; + use arrow_buffer::i256; use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::internal_err; use datafusion_common::DataFusionError; @@ -367,6 +615,35 @@ mod tests { }}; } + macro_rules! 
test_count_distinct_update_batch_bigint { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + let values: Vec> = vec![ + Some(i256::from(1)), + Some(i256::from(1)), + None, + Some(i256::from(3)), + Some(i256::from(2)), + None, + Some(i256::from(2)), + Some(i256::from(3)), + Some(i256::from(1)), + ]; + + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); + state_vec.sort(); + + assert_eq!(states.len(), 1); + assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); + assert_eq!(result, ScalarValue::Int64(Some(3))); + + Ok(()) + }}; + } + #[test] fn count_distinct_update_batch_i8() -> Result<()> { test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) @@ -417,6 +694,11 @@ mod tests { test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) } + #[test] + fn count_distinct_update_batch_i256() -> Result<()> { + test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) + } + #[test] fn count_distinct_update_batch_boolean() -> Result<()> { let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 0cf4a90ab8cc..6dbb39224629 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef}; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; -use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_buffer::ArrowNativeType; use std::collections::HashSet; use crate::aggregate::sum::downcast_sum; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::sum_return_type; @@ -119,24 +119,6 @@ impl PartialEq for DistinctSum { } } -/// A wrapper around a type to provide hash for floats -#[derive(Copy, Clone)] -struct Hashable(T); - -impl std::hash::Hash for Hashable { - fn hash(&self, state: &mut H) { - self.0.to_byte_slice().hash(state) - } -} - -impl PartialEq for Hashable { - fn eq(&self, other: &Self) -> bool { - self.0.is_eq(other.0) - } -} - -impl Eq for Hashable {} - struct DistinctSumAccumulator { values: HashSet, RandomState>, data_type: DataType, diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 90f5244f477d..78708df94c25 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,7 +28,6 @@ //! 
[Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; -use arrow_array::cast::as_list_array; use arrow_array::types::Float64Type; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; @@ -606,11 +605,10 @@ impl TDigest { let centroids: Vec<_> = match &state[5] { ScalarValue::List(arr) => { - let list_array = as_list_array(arr); - let arr = list_array.values(); + let array = arr.values(); let f64arr = - as_primitive_array::(arr).expect("expected f64 array"); + as_primitive_array::(array).expect("expected f64 array"); f64arr .values() .chunks(2) diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 9777158da133..d73c46a0f687 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -28,7 +28,7 @@ use arrow_array::types::{ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_buffer::ArrowNativeType; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -211,3 +211,21 @@ pub(crate) fn ordering_fields( pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { ordering_req.iter().map(|item| item.options).collect() } + +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone, Debug)] +pub(crate) struct Hashable(pub T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 15330af640ae..5b35c4b9d8fb 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -27,7 +27,7 @@ use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; -use arrow_buffer::NullBuffer; +use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ @@ -36,7 +36,8 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, ScalarValue, }; use itertools::Itertools; @@ -529,7 +530,7 @@ fn general_except( pub fn array_except(args: &[ArrayRef]) -> Result { if args.len() != 2 { - return internal_err!("array_except needs two arguments"); + return exec_err!("array_except needs two arguments"); } let array1 = &args[0]; @@ -894,7 +895,7 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { as_int64_array(&args[1])?, Some(as_int64_array(&args[2])?), ), - _ => return internal_err!("gen_range expects 1 to 3 arguments"), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), }; let mut values = vec![]; @@ -948,7 +949,7 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { nulls_first: order_nulls_first(nulls_first)?, }) } - _ => return internal_err!("array_sort expects 1 to 3 arguments"), + _ => 
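The relocated `Hashable` wrapper exists because `f16`/`f32`/`f64` do not implement `Hash`/`Eq`. A minimal std-only sketch of the same idea (editor's addition; the real code hashes arrow's `ToByteSlice` bytes and compares with `ArrowNativeTypeOp::is_eq`, while this simplified stand-in uses the raw bit pattern for both):

```rust
use std::collections::HashSet;
use std::hash::{Hash, Hasher};

/// Simplified, editor-provided stand-in for `Hashable`: treat a float as its
/// bit pattern so it can be stored in a HashSet.
#[derive(Copy, Clone, Debug)]
struct BitwiseF64(f64);

impl Hash for BitwiseF64 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.to_ne_bytes().hash(state);
    }
}

impl PartialEq for BitwiseF64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.to_ne_bytes() == other.0.to_ne_bytes()
    }
}

impl Eq for BitwiseF64 {}

fn main() {
    let mut set = HashSet::new();
    for v in [1.0_f64, 1.0, 0.0, -0.0, f64::NAN, f64::NAN] {
        set.insert(BitwiseF64(v));
    }
    // Bit-wise identity: the duplicate 1.0 and NaN collapse, while 0.0 and
    // -0.0 remain distinct keys.
    assert_eq!(set.len(), 4);
}
```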
return exec_err!("array_sort expects 1 to 3 arguments"), }; let list_array = as_list_array(&args[0])?; @@ -994,7 +995,7 @@ fn order_desc(modifier: &str) -> Result { match modifier.to_uppercase().as_str() { "DESC" => Ok(true), "ASC" => Ok(false), - _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), } } @@ -1002,7 +1003,7 @@ fn order_nulls_first(modifier: &str) -> Result { match modifier.to_uppercase().as_str() { "NULLS FIRST" => Ok(true), "NULLS LAST" => Ok(false), - _ => internal_err!( + _ => exec_err!( "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" ), } @@ -1190,7 +1191,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { } } - concat_internal::(new_args.as_slice()) + match &args[0].data_type() { + DataType::LargeList(_) => concat_internal::(new_args.as_slice()), + _ => concat_internal::(new_args.as_slice()), + } } /// Array_empty SQL function @@ -1208,7 +1212,7 @@ pub fn array_empty(args: &[ArrayRef]) -> Result { match array_type { DataType::List(_) => array_empty_dispatch::(&args[0]), DataType::LargeList(_) => array_empty_dispatch::(&args[0]), - _ => internal_err!("array_empty does not support type '{array_type:?}'."), + _ => exec_err!("array_empty does not support type '{array_type:?}'."), } } @@ -1239,7 +1243,7 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result { let list_array = as_large_list_array(element)?; general_list_repeat::(list_array, count_array) } - _ => general_repeat(element, count_array), + _ => general_repeat::(element, count_array), } } @@ -1255,7 +1259,10 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result { /// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] /// ) /// ``` -fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result { +fn general_repeat( + array: &ArrayRef, + count_array: &Int64Array, +) -> Result { let data_type = array.data_type(); let mut new_values = vec![]; @@ -1288,7 +1295,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result = new_values.iter().map(|a| a.as_ref()).collect(); let values = compute::concat(&new_values)?; - Ok(Arc::new(ListArray::try_new( + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), OffsetBuffer::from_lengths(count_vec), values, @@ -1598,7 +1605,9 @@ fn array_remove_internal( let list_array = array.as_list::(); general_remove::(list_array, element_array, arr_n) } - _ => internal_err!("array_remove_all expects a list array"), + array_type => { + exec_err!("array_remove_all does not support type '{array_type:?}'.") + } } } @@ -2022,8 +2031,21 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { ) -> Result<&mut String> { match arr.data_type() { DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray); + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + Ok(arg) + } + DataType::LargeList(..) 
=> { + let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( arg, @@ -2055,35 +2077,61 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } } - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - - match arr.data_type() { - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { - let list_array = arr.as_list::(); - for (arr, &delimiter) in list_array.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); } else { - res.push(None); + res.push(Some(s)); } + } else { + res.push(None); } } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); // delimiter length is 1 assert_eq!(delimiters.len(), 1); let delimiter = delimiters[0].unwrap(); @@ -2102,10 +2150,11 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } else { res.push(Some(s)); } + StringArray::from(res) } - } + }; - Ok(Arc::new(StringArray::from(res))) + Ok(Arc::new(string_arr)) } /// Cardinality SQL function @@ -2114,16 +2163,31 @@ pub fn cardinality(args: &[ArrayRef]) -> Result { return exec_err!("cardinality expects one argument"); } - let list_array = as_list_array(&args[0])?.clone(); + match &args[0].data_type() { + DataType::List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} - let result = list_array +fn generic_list_cardinality( + array: &GenericListArray, +) -> Result { + let result = array .iter() .map(|arr| match compute_array_dims(arr)? 
{ Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), None => Ok(None), }) .collect::>()?; - Ok(Arc::new(result) as ArrayRef) } @@ -2205,10 +2269,7 @@ pub fn array_length(args: &[ArrayRef]) -> Result { match &args[0].data_type() { DataType::List(_) => array_length_dispatch::(args), DataType::LargeList(_) => array_length_dispatch::(args), - _ => internal_err!( - "array_length does not support type '{:?}'", - args[0].data_type() - ), + array_type => exec_err!("array_length does not support type '{array_type:?}'"), } } @@ -2233,11 +2294,8 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { .map(compute_array_dims) .collect::>>()? } - _ => { - return exec_err!( - "array_dims does not support type '{:?}'", - args[0].data_type() - ); + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); } }; @@ -2386,7 +2444,7 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } - _ => internal_err!("array_has_any does not support type '{array_type:?}'."), + _ => exec_err!("array_has_any does not support type '{array_type:?}'."), } } @@ -2405,7 +2463,7 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } - _ => internal_err!("array_has_all does not support type '{array_type:?}'."), + _ => exec_err!("array_has_all does not support type '{array_type:?}'."), } } @@ -2488,7 +2546,7 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result { - return internal_err!( + return exec_err!( "Expect string_to_array function to take two or three parameters" ) } @@ -2556,8 +2614,102 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result { let array = as_large_list_array(&args[0])?; general_array_distinct(array, field) } - _ => internal_err!("array_distinct only support list array"), + array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), + } +} + +/// array_resize SQL function +pub fn array_resize(arg: &[ArrayRef]) -> Result { + if arg.len() < 2 || arg.len() > 3 { + return exec_err!("array_resize needs two or three arguments"); + } + + let new_len = as_int64_array(&arg[1])?; + let new_element = if arg.len() == 3 { + Some(arg[2].clone()) + } else { + None + }; + + match &arg[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + array_type => exec_err!("array_resize does not support type '{array_type:?}'."), + } +} + +/// array_resize keep the original array and append the default element to the end +fn general_list_resize( + array: &GenericListArray, + count_array: &Int64Array, + field: &FieldRef, + default_element: Option, +) -> Result +where + O: TryInto, +{ + let data_type = array.value_type(); + + let values = array.values(); + let original_data = values.to_data(); + + // create default element array + let default_element = if let Some(default_element) = default_element { + default_element + } else { + let null_scalar = ScalarValue::try_from(&data_type)?; + null_scalar.to_array_of_size(original_data.len())? 
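As an editorial aid, the per-row behaviour that the `MutableArrayData` loop below implements can be pictured on plain `Vec`s: each list row is either truncated to the requested length or padded with the default element (a NULL scalar when no third argument is given).

```rust
/// Editor's sketch of the resize semantics, using Vec in place of Arrow lists.
fn resize_row<T: Clone>(row: &[T], new_len: usize, default: T) -> Vec<T> {
    let mut out: Vec<T> = row.iter().take(new_len).cloned().collect();
    while out.len() < new_len {
        out.push(default.clone());
    }
    out
}

fn main() {
    assert_eq!(resize_row(&[1, 2, 3], 5, 0), vec![1, 2, 3, 0, 0]); // pad
    assert_eq!(resize_row(&[1, 2, 3], 2, 0), vec![1, 2]); // truncate
}
```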
+ }; + let default_value_data = default_element.to_data(); + + // create a mutable array to store the original data + let capacity = Capacities::Array(original_data.len() + default_value_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &default_value_data], + false, + capacity, + ); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let count = count_array.value(row_index).to_usize().ok_or_else(|| { + internal_datafusion_err!("array_resize: failed to convert size to usize") + })?; + let count = O::usize_as(count); + let start = offset_window[0]; + if start + count > offset_window[1] { + let extra_count = + (start + count - offset_window[1]).try_into().map_err(|_| { + internal_datafusion_err!( + "array_resize: failed to convert size to i64" + ) + })?; + let end = offset_window[1]; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + // append default element + for _ in 0..extra_count { + mutable.extend(1, row_index, row_index + 1); + } + } else { + let end = start + count; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + }; + offsets.push(offsets[row_index] + count); } + + let data = mutable.freeze(); + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c17081398cb8..8c4078dbce8c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,9 +20,7 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::array_expressions::{ - array_append, array_concat, array_has_all, array_prepend, -}; +use crate::array_expressions::array_has_all; use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; @@ -598,12 +596,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => match (left_data_type, right_data_type) { - (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), - (DataType::List(_), _) => array_append(&[left, right]), - (_, DataType::List(_)) => array_prepend(&[left, right]), - _ => binary_string_array_op!(left, right, concat_elements), - }, + StringConcat => binary_string_array_op!(left, right, concat_elements), AtArrow => array_has_all(&[left, right]), ArrowAt => array_has_all(&[right, left]), } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 53de85843919..66e22d2302de 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -419,6 +419,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } + BuiltinScalarFunction::ArrayResize => { + Arc::new(|args| make_scalar_function(array_expressions::array_resize)(args)) + } BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 
9c212cb81f6b..09b8da836c30 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -43,29 +43,17 @@ use std::sync::Arc; /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references /// to qualified or unqualified fields by name. -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, - input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result> { - if input_schema.fields.len() != input_dfschema.fields().len() { - return internal_err!( - "create_physical_expr expected same number of fields, got \ - Arrow schema with {} and DataFusion schema with {}", - input_schema.fields.len(), - input_dfschema.fields().len() - ); - } + let input_schema: &Schema = &input_dfschema.into(); + match e { - Expr::Alias(Alias { expr, .. }) => Ok(create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?), + Expr::Alias(Alias { expr, .. }) => { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) @@ -96,12 +84,7 @@ pub fn create_physical_expr( Operator::IsNotDistinctFrom, Expr::Literal(ScalarValue::Boolean(Some(true))), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsNotTrue(expr) => { let binary_op = binary_expr( @@ -109,12 +92,7 @@ pub fn create_physical_expr( Operator::IsDistinctFrom, Expr::Literal(ScalarValue::Boolean(Some(true))), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsFalse(expr) => { let binary_op = binary_expr( @@ -122,12 +100,7 @@ pub fn create_physical_expr( Operator::IsNotDistinctFrom, Expr::Literal(ScalarValue::Boolean(Some(false))), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsNotFalse(expr) => { let binary_op = binary_expr( @@ -135,12 +108,7 @@ pub fn create_physical_expr( Operator::IsDistinctFrom, Expr::Literal(ScalarValue::Boolean(Some(false))), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsUnknown(expr) => { let binary_op = binary_expr( @@ -148,12 +116,7 @@ pub fn create_physical_expr( Operator::IsNotDistinctFrom, Expr::Literal(ScalarValue::Boolean(None)), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsNotUnknown(expr) => { let binary_op = binary_expr( @@ -161,27 +124,12 @@ pub fn create_physical_expr( Operator::IsDistinctFrom, Expr::Literal(ScalarValue::Boolean(None)), ); - create_physical_expr( - &binary_op, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { // Create physical expressions for left and right operands - let lhs = create_physical_expr( - left, - input_dfschema, - input_schema, - 
execution_props, - )?; - let rhs = create_physical_expr( - right, - input_dfschema, - input_schema, - execution_props, - )?; + let lhs = create_physical_expr(left, input_dfschema, execution_props)?; + let rhs = create_physical_expr(right, input_dfschema, execution_props)?; // Note that the logical planner is responsible // for type coercion on the arguments (e.g. if one // argument was originally Int32 and one was @@ -201,18 +149,10 @@ pub fn create_physical_expr( if escape_char.is_some() { return exec_err!("LIKE does not support escape_char"); } - let physical_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?; - let physical_pattern = create_physical_expr( - pattern, - input_dfschema, - input_schema, - execution_props, - )?; + let physical_expr = + create_physical_expr(expr, input_dfschema, execution_props)?; + let physical_pattern = + create_physical_expr(pattern, input_dfschema, execution_props)?; like( *negated, *case_insensitive, @@ -226,7 +166,6 @@ pub fn create_physical_expr( Some(create_physical_expr( e.as_ref(), input_dfschema, - input_schema, execution_props, )?) } else { @@ -236,24 +175,14 @@ pub fn create_physical_expr( .when_then_expr .iter() .map(|(w, _)| { - create_physical_expr( - w.as_ref(), - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(w.as_ref(), input_dfschema, execution_props) }) .collect::>>()?; let then_expr = case .when_then_expr .iter() .map(|(_, t)| { - create_physical_expr( - t.as_ref(), - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(t.as_ref(), input_dfschema, execution_props) }) .collect::>>()?; let when_then_expr: Vec<(Arc, Arc)> = @@ -267,7 +196,6 @@ pub fn create_physical_expr( Some(create_physical_expr( e.as_ref(), input_dfschema, - input_schema, execution_props, )?) } else { @@ -276,35 +204,30 @@ pub fn create_physical_expr( Ok(expressions::case(expr, when_then_expr, else_expr)?) } Expr::Cast(Cast { expr, data_type }) => expressions::cast( - create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( - create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), - Expr::Not(expr) => expressions::not(create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?), + Expr::Not(expr) => { + expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) 
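The net effect of this planner change is a three-argument `create_physical_expr`; the Arrow schema is now derived from the `DFSchema` internally. A hedged usage sketch mirroring the updated test below (import paths, in particular for `ExecutionProps`, have moved between DataFusion releases, so treat them as assumptions):

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{col, lit};
use datafusion_physical_expr::create_physical_expr;
// Path assumption: ExecutionProps lived in datafusion-physical-expr at this
// point in time and was later moved to datafusion-expr.
use datafusion_physical_expr::execution_props::ExecutionProps;

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let df_schema = DFSchema::try_from_qualified_schema("t", &schema)?;

    // No separate Arrow schema argument any more.
    let expr = col("a").gt(lit(5));
    let _physical = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
    Ok(())
}
```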
+ } Expr::Negative(expr) => expressions::negative( - create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, ), Expr::IsNull(expr) => expressions::is_null(create_physical_expr( expr, input_dfschema, - input_schema, execution_props, )?), Expr::IsNotNull(expr) => expressions::is_not_null(create_physical_expr( expr, input_dfschema, - input_schema, execution_props, )?), Expr::GetIndexedField(GetIndexedField { expr, field }) => { @@ -313,37 +236,25 @@ pub fn create_physical_expr( GetFieldAccessExpr::NamedStructField { name: name.clone() } } GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex { - key: create_physical_expr( - key, - input_dfschema, - input_schema, - execution_props, - )?, + key: create_physical_expr(key, input_dfschema, execution_props)?, }, GetFieldAccess::ListRange { start, stop } => { GetFieldAccessExpr::ListRange { start: create_physical_expr( start, input_dfschema, - input_schema, execution_props, )?, stop: create_physical_expr( stop, input_dfschema, - input_schema, execution_props, )?, } } }; Ok(Arc::new(GetIndexedFieldExpr::new( - create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?, + create_physical_expr(expr, input_dfschema, execution_props)?, field, ))) } @@ -351,9 +262,7 @@ pub fn create_physical_expr( Expr::ScalarFunction(ScalarFunction { func_def, args }) => { let mut physical_args = args .iter() - .map(|e| { - create_physical_expr(e, input_dfschema, input_schema, execution_props) - }) + .map(|e| create_physical_expr(e, input_dfschema, execution_props)) .collect::>>()?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -386,20 +295,9 @@ pub fn create_physical_expr( low, high, }) => { - let value_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?; - let low_expr = - create_physical_expr(low, input_dfschema, input_schema, execution_props)?; - let high_expr = create_physical_expr( - high, - input_dfschema, - input_schema, - execution_props, - )?; + let value_expr = create_physical_expr(expr, input_dfschema, execution_props)?; + let low_expr = create_physical_expr(low, input_dfschema, execution_props)?; + let high_expr = create_physical_expr(high, input_dfschema, execution_props)?; // rewrite the between into the two binary operators let binary_expr = binary( @@ -424,22 +322,13 @@ pub fn create_physical_expr( Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { - let value_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?; + let value_expr = + create_physical_expr(expr, input_dfschema, execution_props)?; let list_exprs = list .iter() .map(|expr| { - create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - ) + create_physical_expr(expr, input_dfschema, execution_props) }) .collect::>>()?; expressions::in_list(value_expr, list_exprs, negated, input_schema) @@ -465,7 +354,7 @@ mod tests { let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; - let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?; + let p = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?; let batch = RecordBatch::try_new( Arc::new(schema), diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 0ec1cf3f256b..de9ba33daf29 100644 --- 
a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -36,9 +36,82 @@ pub fn create_physical_expr( Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun().clone(), + fun.fun(), input_phy_exprs.to_vec(), fun.return_type(&input_exprs_types)?, - None, + fun.monotonicity()?, ))) } + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use arrow_schema::DataType; + use datafusion_common::Result; + use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + }; + + use crate::ScalarFunctionExpr; + + use super::create_physical_expr; + + #[test] + fn test_functions() -> Result<()> { + #[derive(Debug, Clone)] + struct TestScalarUDF { + signature: Signature, + } + + impl TestScalarUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + + Self { signature } + } + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "my_fn" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("my_fn is not implemented") + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } + } + + // create and register the udf + let udf = ScalarUDF::from(TestScalarUDF::new()); + + let p_expr = create_physical_expr(&udf, &[], &Schema::empty())?; + + assert_eq!( + p_expr + .as_any() + .downcast_ref::() + .unwrap() + .monotonicity(), + &Some(vec![Some(true)]) + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 0aee2af67fdd..26ee95f4793c 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -23,6 +23,7 @@ use crate::{split_conjunction, PhysicalExpr}; use datafusion_common::{Column, ScalarValue}; use datafusion_expr::Operator; use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Display, Formatter}; use std::sync::Arc; /// Represents a guarantee that must be true for a boolean expression to @@ -222,6 +223,33 @@ impl LiteralGuarantee { } } +impl Display for LiteralGuarantee { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.guarantee { + Guarantee::In => write!( + f, + "{} in ({})", + self.column.name, + self.literals + .iter() + .map(|lit| lit.to_string()) + .collect::>() + .join(", ") + ), + Guarantee::NotIn => write!( + f, + "{} not in ({})", + self.column.name, + self.literals + .iter() + .map(|lit| lit.to_string()) + .collect::>() + .join(", ") + ), + } + } +} + /// Combines conjuncts (aka terms `AND`ed together) into [`LiteralGuarantee`]s, /// preserving insert order #[derive(Debug, Default)] @@ -398,6 +426,7 @@ mod test { use datafusion_common::ToDFSchema; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; + use itertools::Itertools; use std::sync::OnceLock; #[test] @@ -691,6 +720,11 @@ mod test { col("b").in_list(vec![lit(1), lit(2), lit(3)], true), vec![not_in_guarantee("b", [1, 2, 3])], ); + // b IN (1,2,3,4...24) + test_analyze( + col("b").in_list((1..25).map(lit).collect_vec(), false), + vec![in_guarantee("b", 1..25)], + ); } #[test] @@ -837,7 +871,7 @@ mod test { fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = 
ExecutionProps::new(); - create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + create_physical_expr(expr, &df_schema, &execution_props).unwrap() } // Schema for testing diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index c5b689496e90..1c638d9c184e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -33,7 +33,9 @@ name = "datafusion_physical_plan" path = "src/lib.rs" [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } @@ -54,11 +56,18 @@ once_cell = "1.18.0" parking_lot = { workspace = true } pin-project-lite = "^0.2.7" rand = { workspace = true } -tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] } +tokio = { version = "1.28", features = ["sync"] } uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] rstest = { workspace = true } rstest_reuse = "0.6.0" termtree = "0.4.1" -tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } +tokio = { version = "1.28", features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", + "fs", + "parking_lot", +] } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index e7c7a42cf902..10ff9edb8912 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -17,18 +17,22 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; +use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::ArrayRef; -use arrow_schema::SchemaRef; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { + /// The output schema + schema: SchemaRef, + /// Converter for the group values row_converter: RowConverter, @@ -75,6 +79,7 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); Ok(Self { + schema, row_converter, map, map_size: 0, @@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows { .take() .expect("Can not emit from empty rows"); - let output = match emit_to { + let mut output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); @@ -198,6 +203,20 @@ impl GroupValues for GroupValuesRows { } }; + // TODO: Materialize dictionaries in group keys (#7647) + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
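The new `Display` impl above renders each guarantee as `col in (...)` or `col not in (...)`. A hedged end-to-end sketch (editor's addition; the module paths for `LiteralGuarantee` and `ExecutionProps` are assumptions based on this crate's layout):

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{Result, ToDFSchema};
use datafusion_expr::{col, lit};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::utils::guarantee::LiteralGuarantee;

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
    let df_schema = schema.clone().to_dfschema()?;

    // b IN (1, 2, 3) yields a single `Guarantee::In` entry.
    let expr = col("b").in_list(vec![lit(1), lit(2), lit(3)], false);
    let physical = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
    for guarantee in LiteralGuarantee::analyze(&physical) {
        // Prints something like `b in (1, 2, 3)`; literal order may vary
        // because the literals are stored in a HashSet.
        println!("{guarantee}");
    }
    Ok(())
}
```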
b/datafusion/physical-plan/src/aggregates/mod.rs index a38044de02e3..4f37be7263f3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,9 +36,8 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ @@ -254,9 +253,6 @@ pub struct AggregateExec { limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, - /// Original aggregation schema, could be different from `schema` before dictionary group - /// keys get materialized - original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the @@ -287,7 +283,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let original_schema = create_schema( + let schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -295,11 +291,7 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(materialize_dict_group_keys( - &original_schema, - group_by.expr.len(), - )); - let original_schema = Arc::new(original_schema); + let schema = Arc::new(schema); AggregateExec::try_new_with_schema( mode, group_by, @@ -308,7 +300,6 @@ impl AggregateExec { input, input_schema, schema, - original_schema, ) } @@ -329,8 +320,12 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, schema: SchemaRef, - original_schema: SchemaRef, ) -> Result { + // Make sure arguments are consistent in size + if aggr_expr.len() != filter_expr.len() { + return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr); + } + let input_eq_properties = input.equivalence_properties(); // Get GROUP BY expressions: let groupby_exprs = group_by.input_exprs(); @@ -382,7 +377,6 @@ impl AggregateExec { aggr_expr, filter_expr, input, - original_schema, schema, input_schema, projection_mapping, @@ -693,7 +687,7 @@ impl ExecutionPlan for AggregateExec { children[0].clone(), self.input_schema.clone(), self.schema.clone(), - self.original_schema.clone(), + //self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -800,24 +794,6 @@ fn create_schema( Ok(Schema::new(fields)) } -/// returns schema with dictionary group keys materialized as their value types -/// The actual convertion happens in `RowConverter` and we don't do unnecessary -/// conversion back into dictionaries -fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { - let fields = schema - .fields - .iter() - .enumerate() - .map(|(i, field)| match field.data_type() { - DataType::Dictionary(_, value_data_type) if i < group_count => { - Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) - } - _ => Field::clone(field), - }) - .collect::>(); - Schema::new(fields) -} - fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) @@ -1824,11 +1800,12 @@ mod tests { (1, groups_some.clone(), aggregates_v1), (2, groups_some, 
aggregates_v2), ] { + let n_aggr = aggregates.len(); let partial_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups, aggregates, - vec![None; 3], + vec![None; n_aggr], input.clone(), input_schema.clone(), )?); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 89614fd3020c..6a0c02f5caf3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -324,9 +324,7 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - // we need to use original schema so RowConverter in group_values below - // will do the proper coversion of dictionaries into value types - let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len()); + let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); let spill_expr = group_schema .fields .into_iter() diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index e7c267817708..e974ffa81ccd 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -72,6 +72,7 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use ahash::RandomState; +use datafusion_physical_expr::PhysicalSortRequirement; use futures::Stream; use hashbrown::HashSet; use parking_lot::Mutex; @@ -185,6 +186,10 @@ pub struct SymmetricHashJoinExec { column_indices: Vec, /// If null_equals_null is true, null == null else null != null pub(crate) null_equals_null: bool, + /// Left side sort expression(s) + pub(crate) left_sort_exprs: Option>, + /// Right side sort expression(s) + pub(crate) right_sort_exprs: Option>, /// Partition Mode mode: StreamJoinPartitionMode, } @@ -196,6 +201,7 @@ impl SymmetricHashJoinExec { /// - It is not possible to join the left and right sides on keys `on`, or /// - It fails to construct `SortedFilterExpr`s, or /// - It fails to create the [ExprIntervalGraph]. + #[allow(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -203,6 +209,8 @@ impl SymmetricHashJoinExec { filter: Option, join_type: &JoinType, null_equals_null: bool, + left_sort_exprs: Option>, + right_sort_exprs: Option>, mode: StreamJoinPartitionMode, ) -> Result { let left_schema = left.schema(); @@ -236,6 +244,8 @@ impl SymmetricHashJoinExec { metrics: ExecutionPlanMetricsSet::new(), column_indices, null_equals_null, + left_sort_exprs, + right_sort_exprs, mode, }) } @@ -275,6 +285,16 @@ impl SymmetricHashJoinExec { self.mode } + /// Get left_sort_exprs + pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + self.left_sort_exprs.as_deref() + } + + /// Get right_sort_exprs + pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + self.right_sort_exprs.as_deref() + } + /// Check if order information covers every column in the filter expression. 
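Editor's note: the two new optional sort-expression fields are what drive the `required_input_ordering` answer further down; a minimal sketch of that mapping (types from `datafusion_physical_expr`):

```rust
use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};

/// Editor's sketch: a join side that was given sort expressions requires that
/// ordering from its input; a side without them imposes no requirement.
fn side_requirement(
    sort_exprs: Option<&Vec<PhysicalSortExpr>>,
) -> Option<Vec<PhysicalSortRequirement>> {
    sort_exprs.map(PhysicalSortRequirement::from_sort_exprs)
}
```

Callers such as `joins::test_utils` simply forward each input's `output_ordering()` (cloned into a `Vec`) as the new constructor arguments, as shown in the hunk that follows.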
pub fn check_if_order_information_available(&self) -> Result { if let Some(filter) = self.filter() { @@ -341,10 +361,6 @@ impl ExecutionPlan for SymmetricHashJoinExec { Ok(children.iter().any(|u| *u)) } - fn benefits_from_input_partitioning(&self) -> Vec { - vec![false, false] - } - fn required_input_distribution(&self) -> Vec { match self.mode { StreamJoinPartitionMode::Partitioned => { @@ -364,6 +380,17 @@ impl ExecutionPlan for SymmetricHashJoinExec { } } + fn required_input_ordering(&self) -> Vec>> { + vec![ + self.left_sort_exprs + .as_ref() + .map(PhysicalSortRequirement::from_sort_exprs), + self.right_sort_exprs + .as_ref() + .map(PhysicalSortRequirement::from_sort_exprs), + ] + } + fn output_partitioning(&self) -> Partitioning { let left_columns_len = self.left.schema().fields.len(); partitioned_join_output_partitioning( @@ -407,6 +434,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.filter.clone(), &self.join_type, self.null_equals_null, + self.left_sort_exprs.clone(), + self.right_sort_exprs.clone(), self.mode, )?)) } @@ -435,24 +464,21 @@ impl ExecutionPlan for SymmetricHashJoinExec { } // If `filter_state` and `filter` are both present, then calculate sorted filter expressions // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( - self.left.output_ordering(), - self.right.output_ordering(), - &self.filter, - ) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = + match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index fbd52ddf0c70..477e2de421b9 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -90,17 +90,19 @@ pub async fn partitioned_sym_join_with_filter( let join = SymmetricHashJoinExec::try_new( Arc::new(RepartitionExec::try_new( - left, + left.clone(), Partitioning::Hash(left_expr, partition_count), )?), Arc::new(RepartitionExec::try_new( - right, + right.clone(), Partitioning::Hash(right_expr, partition_count), )?), on, filter, join_type, null_equals_null, + left.output_ordering().map(|p| p.to_vec()), + right.output_ordering().map(|p| p.to_vec()), StreamJoinPartitionMode::Partitioned, )?; diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index cae48c627f68..01d4f8941802 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -288,6 +288,24 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// [`TryStreamExt`]: futures::stream::TryStreamExt /// [`RecordBatchStreamAdapter`]: 
crate::stream::RecordBatchStreamAdapter /// + /// # Cancellation / Aborting Execution + /// + /// The [`Stream`] that is returned must ensure that any allocated resources + /// are freed when the stream itself is dropped. This is particularly + /// important for [`spawn`]ed tasks or threads. Unless care is taken to + /// "abort" such tasks, they may continue to consume resources even after + /// the plan is dropped, generating intermediate results that are never + /// used. + /// + /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and + /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all + /// background tasks are cancelled. + /// + /// [`spawn`]: tokio::task::spawn + /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle + /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany + /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder + /// /// # Implementation Examples /// /// While `async` `Stream`s have a non trivial learning curve, the @@ -491,7 +509,12 @@ pub async fn collect( common::collect(stream).await } -/// Execute the [ExecutionPlan] and return a single stream of results +/// Execute the [ExecutionPlan] and return a single stream of results. +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources pub fn execute_stream( plan: Arc, context: Arc, @@ -549,7 +572,13 @@ pub async fn collect_partitioned( Ok(batches) } -/// Execute the [ExecutionPlan] and return a vec with one stream per output partition +/// Execute the [ExecutionPlan] and return a vec with one stream per output +/// partition +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources pub fn execute_stream_partitioned( plan: Arc, context: Arc, diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 37e8ffd76159..c31d5f62c726 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -877,7 +877,7 @@ mod tests { AggregateMode::Final, build_group_by(&csv.schema().clone(), vec!["i".to_string()]), vec![], - vec![None], + vec![], csv.clone(), csv.schema().clone(), )?; diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index b624fb362e65..f82f7ea2f869 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -27,9 +27,8 @@ use crate::{ PhysicalExpr, }; -use arrow::array::new_null_array; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_execution::TaskContext; @@ -53,15 +52,14 @@ impl ValuesExec { } let n_row = data.len(); let n_col = schema.fields().len(); - // we have this single row, null, typed batch as a placeholder to satisfy evaluation argument - let batch = RecordBatch::try_new( - schema.clone(), - schema - .fields() - .iter() - .map(|field| new_null_array(field.data_type(), 1)) - .collect::>(), + // we have this single row batch as a placeholder to satisfy evaluation argument + // and generate a single output row + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(1)), )?; + let arr = (0..n_col) .map(|j| { (0..n_row) @@ 
-71,7 +69,7 @@ impl ValuesExec { match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - Ok(ScalarValue::List(a)) + ScalarValue::try_from_array(&a, 0) } Ok(ColumnarValue::Array(a)) => { plan_err!( @@ -174,7 +172,7 @@ impl ExecutionPlan for ValuesExec { partition: usize, _context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { - // GlobalLimitExec has a single output partition + // ValuesExec has a single output partition if 0 != partition { return internal_err!( "ValuesExec invalid partition {partition} (expected 0)" @@ -201,6 +199,7 @@ impl ExecutionPlan for ValuesExec { #[cfg(test)] mod tests { use super::*; + use crate::expressions::lit; use crate::test::{self, make_partition}; use arrow_schema::{DataType, Field, Schema}; @@ -240,4 +239,18 @@ mod tests { ])); let _ = ValuesExec::try_new_from_batches(invalid_schema, batches).unwrap_err(); } + + // Test issue: https://github.com/apache/arrow-datafusion/issues/8763 + #[test] + fn new_exec_with_non_nullable_schema() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col0", + DataType::UInt32, + false, + )])); + let _ = ValuesExec::try_new(schema.clone(), vec![vec![lit(1u32)]]).unwrap(); + // Test that a null value is rejected + let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) .unwrap_err(); + } } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index 171aadb744d6..8d25f193fa6b 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -17,11 +17,40 @@ under the License. --> -# DataFusion Proto +# Apache Arrow DataFusion Proto -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +Apache Arrow [DataFusion][df] is an extensible query execution framework, +written in Rust, that uses Apache Arrow as its in-memory format. -This crate is a submodule of DataFusion that provides a protocol buffer format for representing query plans and expressions. +This crate provides support for serializing and deserializing the +following structures to and from bytes: + +1. [`LogicalPlan`]'s (including [`Expr`]), +2. [`ExecutionPlan`]s (including [`PhysicalExpr`]) + +This format can be useful for sending plans over the network, for example when +building a distributed query engine. + +Internally, this crate is implemented by converting the plans to [protocol +buffers] using [prost]. + +[protocol buffers]: https://developers.google.com/protocol-buffers +[`logicalplan`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.LogicalPlan.html +[`expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/expr/enum.Expr.html +[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html +[`physicalexpr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/trait.PhysicalExpr.html +[prost]: https://docs.rs/prost/latest/prost/ + +## See Also + +The binary format created by this crate supports the full range of DataFusion +plans, but is DataFusion specific. See [datafusion-substrait] which can encode +many DataFusion plans using the [substrait.io] standard.
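Since the paragraph above mentions shipping plans over the network, a hedged round-trip sketch may help (editor's addition): it assumes the crate's `bytes` helpers (`logical_plan_to_bytes` / `logical_plan_from_bytes`) and a tokio runtime, so check the crate docs for the exact API on your release.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Build a plan, encode it to protobuf bytes, then decode it against a
    // context that knows the same catalog (none is needed for SELECT 1).
    let plan = ctx.sql("SELECT 1 AS a").await?.into_optimized_plan()?;
    let bytes = logical_plan_to_bytes(&plan)?;
    let round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
    assert_eq!(format!("{plan:?}"), format!("{round_trip:?}"));
    Ok(())
}
```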
+ +[datafusion-substrait]: https://docs.rs/datafusion-substrait/latest/datafusion_substrait +[substrait.io]: https://substrait.io + +# Examples ## Serializing Expressions diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d5f8397aa30c..c95465b5ae44 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -212,7 +212,8 @@ message CreateExternalTableNode { bool if_not_exists = 7; string delimiter = 8; string definition = 9; - string file_compression_type = 10; + reserved 10; // was string file_compression_type + CompressionTypeVariant file_compression_type = 17; repeated LogicalExprNodeCollection order_exprs = 13; bool unbounded = 14; map options = 11; @@ -667,6 +668,7 @@ enum ScalarFunction { FindInSet = 127; ArraySort = 128; ArrayDistinct = 129; + ArrayResize = 130; } message ScalarFunctionNode { @@ -1536,6 +1538,8 @@ message SymmetricHashJoinExecNode { StreamPartitionMode partition_mode = 6; bool null_equals_null = 7; JoinFilter filter = 8; + repeated PhysicalSortExprNode left_sort_exprs = 9; + repeated PhysicalSortExprNode right_sort_exprs = 10; } message InterleaveExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 12e834d75adf..d5d86b2179fa 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4175,7 +4175,7 @@ impl serde::Serialize for CreateExternalTableNode { if !self.definition.is_empty() { len += 1; } - if !self.file_compression_type.is_empty() { + if self.file_compression_type != 0 { len += 1; } if !self.order_exprs.is_empty() { @@ -4221,8 +4221,10 @@ impl serde::Serialize for CreateExternalTableNode { if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } - if !self.file_compression_type.is_empty() { - struct_ser.serialize_field("fileCompressionType", &self.file_compression_type)?; + if self.file_compression_type != 0 { + let v = CompressionTypeVariant::try_from(self.file_compression_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.file_compression_type)))?; + struct_ser.serialize_field("fileCompressionType", &v)?; } if !self.order_exprs.is_empty() { struct_ser.serialize_field("orderExprs", &self.order_exprs)?; @@ -4420,7 +4422,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { if file_compression_type__.is_some() { return Err(serde::de::Error::duplicate_field("fileCompressionType")); } - file_compression_type__ = Some(map_.next_value()?); + file_compression_type__ = Some(map_.next_value::()? 
as i32); } GeneratedField::OrderExprs => { if order_exprs__.is_some() { @@ -22332,6 +22334,7 @@ impl serde::Serialize for ScalarFunction { Self::FindInSet => "FindInSet", Self::ArraySort => "ArraySort", Self::ArrayDistinct => "ArrayDistinct", + Self::ArrayResize => "ArrayResize", }; serializer.serialize_str(variant) } @@ -22473,6 +22476,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "FindInSet", "ArraySort", "ArrayDistinct", + "ArrayResize", ]; struct GeneratedVisitor; @@ -22643,6 +22647,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "FindInSet" => Ok(ScalarFunction::FindInSet), "ArraySort" => Ok(ScalarFunction::ArraySort), "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), + "ArrayResize" => Ok(ScalarFunction::ArrayResize), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -25671,6 +25676,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode { if self.filter.is_some() { len += 1; } + if !self.left_sort_exprs.is_empty() { + len += 1; + } + if !self.right_sort_exprs.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; if let Some(v) = self.left.as_ref() { struct_ser.serialize_field("left", v)?; @@ -25697,6 +25708,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode { if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } + if !self.left_sort_exprs.is_empty() { + struct_ser.serialize_field("leftSortExprs", &self.left_sort_exprs)?; + } + if !self.right_sort_exprs.is_empty() { + struct_ser.serialize_field("rightSortExprs", &self.right_sort_exprs)?; + } struct_ser.end() } } @@ -25717,6 +25734,10 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "null_equals_null", "nullEqualsNull", "filter", + "left_sort_exprs", + "leftSortExprs", + "right_sort_exprs", + "rightSortExprs", ]; #[allow(clippy::enum_variant_names)] @@ -25728,6 +25749,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { PartitionMode, NullEqualsNull, Filter, + LeftSortExprs, + RightSortExprs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25756,6 +25779,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), "filter" => Ok(GeneratedField::Filter), + "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), + "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25782,6 +25807,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { let mut partition_mode__ = None; let mut null_equals_null__ = None; let mut filter__ = None; + let mut left_sort_exprs__ = None; + let mut right_sort_exprs__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Left => { @@ -25826,6 +25853,18 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { } filter__ = map_.next_value()?; } + GeneratedField::LeftSortExprs => { + if left_sort_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("leftSortExprs")); + } + left_sort_exprs__ = Some(map_.next_value()?); + } + GeneratedField::RightSortExprs => { + if right_sort_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("rightSortExprs")); + } + right_sort_exprs__ = Some(map_.next_value()?); + } } } Ok(SymmetricHashJoinExecNode { @@ -25836,6 +25875,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { partition_mode: partition_mode__.unwrap_or_default(), null_equals_null: null_equals_null__.unwrap_or_default(), filter: filter__, + left_sort_exprs: left_sort_exprs__.unwrap_or_default(), + right_sort_exprs: right_sort_exprs__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4ee0b70325ca..7e262e620fa7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -349,8 +349,8 @@ pub struct CreateExternalTableNode { pub delimiter: ::prost::alloc::string::String, #[prost(string, tag = "9")] pub definition: ::prost::alloc::string::String, - #[prost(string, tag = "10")] - pub file_compression_type: ::prost::alloc::string::String, + #[prost(enumeration = "CompressionTypeVariant", tag = "17")] + pub file_compression_type: i32, #[prost(message, repeated, tag = "13")] pub order_exprs: ::prost::alloc::vec::Vec, #[prost(bool, tag = "14")] @@ -2178,6 +2178,10 @@ pub struct SymmetricHashJoinExecNode { pub null_equals_null: bool, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, + #[prost(message, repeated, tag = "9")] + pub left_sort_exprs: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "10")] + pub right_sort_exprs: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2754,6 +2758,7 @@ pub enum ScalarFunction { FindInSet = 127, ArraySort = 128, ArrayDistinct = 129, + ArrayResize = 130, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2892,6 +2897,7 @@ impl ScalarFunction { ScalarFunction::FindInSet => "FindInSet", ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayDistinct => "ArrayDistinct", + ScalarFunction::ArrayResize => "ArrayResize", } } /// Creates an enum from field names used in the ProtoBuf definition. 
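
The `file_compression_type` change shown in these hunks follows the usual prost pattern for `enumeration` fields: the message stores a plain `i32`, the old `string` tag is kept `reserved` so stale writers cannot be silently reinterpreted, and readers convert back to the enum with a fallible `try_from`, surfacing unknown discriminants as errors rather than defaults. The sketch below is a self-contained illustration of that decode step, not the generated code itself; the struct, the discriminant values, and the error type are stand-ins that mirror the shape of the diff.

```rust
// Illustrative stand-in for the prost-generated enum; the real message stores it as i32.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompressionTypeVariant {
    Gzip,
    Bzip2,
    Xz,
    Zstd,
    Uncompressed,
}

impl TryFrom<i32> for CompressionTypeVariant {
    type Error = String;
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        // Discriminant values here are illustrative, not copied from datafusion.proto.
        match value {
            0 => Ok(Self::Gzip),
            1 => Ok(Self::Bzip2),
            2 => Ok(Self::Xz),
            3 => Ok(Self::Zstd),
            4 => Ok(Self::Uncompressed),
            // Unknown values must become an error, which is what the plan decoder
            // maps to a proto_error instead of guessing a default.
            other => Err(format!("Unknown file compression type {other}")),
        }
    }
}

// Hypothetical mirror of CreateExternalTableNode: the enum field is an i32 under tag 17,
// while tag 10 (the old string field) stays reserved.
struct CreateExternalTableNode {
    file_compression_type: i32,
}

fn main() {
    let node = CreateExternalTableNode { file_compression_type: 3 };
    let variant = CompressionTypeVariant::try_from(node.file_compression_type)
        .expect("valid compression type");
    assert_eq!(variant, CompressionTypeVariant::Zstd);
    println!("decoded {variant:?}");
}
```
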
@@ -3027,6 +3033,7 @@ impl ScalarFunction { "FindInSet" => Some(Self::FindInSet), "ArraySort" => Some(Self::ArraySort), "ArrayDistinct" => Some(Self::ArrayDistinct), + "ArrayResize" => Some(Self::ArrayResize), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 36c5b44f00b9..c11599412d94 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -27,6 +27,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::AsArray, buffer::Buffer, datatypes::{ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, @@ -506,6 +507,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::ArrayUnion => Self::ArrayUnion, + ScalarFunction::ArrayResize => Self::ArrayResize, ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, @@ -722,9 +724,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { - Value::ListValue(_) => Self::List(arr.to_owned()), - Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), - Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + Value::ListValue(_) => { + Self::List(arr.as_list::().to_owned().into()) + } + Value::LargeListValue(_) => { + Self::LargeList(arr.as_list::().to_owned().into()) + } + Value::FixedSizeListValue(_) => { + Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) + } _ => unreachable!(), } } @@ -1492,6 +1500,11 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayResize => Ok(array_slice( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e8a38784481b..6ca95519a9b1 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -538,6 +538,16 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults.insert(col_name.clone(), expr); } + let file_compression_type = protobuf::CompressionTypeVariant::try_from( + create_extern_table.file_compression_type, + ) + .map_err(|_| { + proto_error(format!( + "Unknown file compression type {}", + create_extern_table.file_compression_type + )) + })?; + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -552,7 +562,7 @@ impl AsLogicalPlan for LogicalPlanNode { .clone(), order_exprs, if_not_exists: create_extern_table.if_not_exists, - file_compression_type: CompressionTypeVariant::from_str(&create_extern_table.file_compression_type).map_err(|_| DataFusionError::NotImplemented(format!("Unsupported file compression type {}", create_extern_table.file_compression_type)))?, + file_compression_type: file_compression_type.into(), definition, unbounded: 
create_extern_table.unbounded, options: create_extern_table.options.clone(), @@ -1410,6 +1420,9 @@ impl AsLogicalPlan for LogicalPlanNode { converted_column_defaults.insert(col_name.clone(), expr.try_into()?); } + let file_compression_type = + protobuf::CompressionTypeVariant::from(file_compression_type); + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1423,7 +1436,7 @@ impl AsLogicalPlan for LogicalPlanNode { delimiter: String::from(*delimiter), order_exprs: converted_order_exprs, definition: definition.clone().unwrap_or_default(), - file_compression_type: file_compression_type.to_string(), + file_compression_type: file_compression_type.into(), unbounded: *unbounded, options: options.clone(), constraints: Some(constraints.clone().into()), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a162b2389cd1..ec9b886c1f22 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -32,6 +32,7 @@ use crate::protobuf::{ OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::ArrayRef, datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, @@ -1159,54 +1160,15 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) => { + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::LargeList(arr) => { // Wrap in a "field_name" column - let batch = RecordBatch::try_from_iter(vec![( - "field_name", - arr.to_owned(), - )]) - .map_err(|e| { - Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!( - "Error encoding ScalarValue::List as IPC: {e}" - )) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarListValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::FixedSizeList(arr) => { + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1523,6 +1485,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend, 
BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat, + BuiltinScalarFunction::ArrayResize => Self::ArrayResize, BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove, BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, @@ -1723,3 +1686,47 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } + +fn encode_scalar_list_value( + arr: ArrayRef, + val: &ScalarValue, +) -> Result { + let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { + Error::General(format!( + "Error creating temporary batch while encoding ScalarValue::List: {e}" + )) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + })?; + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 23ab813ca739..ea28eeee8810 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -93,6 +93,36 @@ pub fn parse_physical_sort_expr( } } +/// Parses a physical sort expressions from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with vector of physical sort expression node +/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +pub fn parse_physical_sort_exprs( + proto: &[protobuf::PhysicalSortExprNode], + registry: &dyn FunctionRegistry, + input_schema: &Schema, +) -> Result> { + proto + .iter() + .map(|sort_expr| { + if let Some(expr) = &sort_expr.expr { + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let options = SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }; + Ok(PhysicalSortExpr { expr, options }) + } else { + Err(proto_error("Unexpected empty physical expression")) + } + }) + .collect::>>() +} + /// Parses a physical window expr from a protobuf. 
/// /// # Arguments @@ -782,6 +812,18 @@ impl From for CompressionTypeVariant { } } +impl From for protobuf::CompressionTypeVariant { + fn from(value: CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { type Error = DataFusionError; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 95becb3fe4b3..f39f885b7838 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -65,7 +65,8 @@ use prost::Message; use crate::common::str_to_byte; use crate::common::{byte_to_string, proto_error}; use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_protobuf_file_scan_config, + parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, + parse_protobuf_file_scan_config, }; use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; @@ -646,6 +647,30 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .map_or(Ok(None), |v: Result| v.map(Some))?; + let left_schema = left.schema(); + let left_sort_exprs = parse_physical_sort_exprs( + &sym_join.left_sort_exprs, + registry, + &left_schema, + )?; + let left_sort_exprs = if left_sort_exprs.is_empty() { + None + } else { + Some(left_sort_exprs) + }; + + let right_schema = right.schema(); + let right_sort_exprs = parse_physical_sort_exprs( + &sym_join.right_sort_exprs, + registry, + &right_schema, + )?; + let right_sort_exprs = if right_sort_exprs.is_empty() { + None + } else { + Some(right_sort_exprs) + }; + let partition_mode = protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { proto_error(format!( @@ -668,6 +693,8 @@ impl AsExecutionPlan for PhysicalPlanNode { filter, &join_type.into(), sym_join.null_equals_null, + left_sort_exprs, + right_sort_exprs, partition_mode, ) .map(|e| Arc::new(e) as _) @@ -1233,6 +1260,40 @@ impl AsExecutionPlan for PhysicalPlanNode { } }; + let left_sort_exprs = exec + .left_sort_exprs() + .map(|exprs| { + exprs + .iter() + .map(|expr| { + Ok(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }) + }) + .collect::>>() + }) + .transpose()? + .unwrap_or(vec![]); + + let right_sort_exprs = exec + .right_sort_exprs() + .map(|exprs| { + exprs + .iter() + .map(|expr| { + Ok(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }) + }) + .collect::>>() + }) + .transpose()? 
+ .unwrap_or(vec![]); + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( protobuf::SymmetricHashJoinExecNode { @@ -1242,6 +1303,8 @@ impl AsExecutionPlan for PhysicalPlanNode { join_type: join_type.into(), partition_mode: partition_mode.into(), null_equals_null: exec.null_equals_null(), + left_sort_exprs, + right_sort_exprs, filter, }, ))), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 402781e17e6f..03daf535f201 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1787,6 +1787,7 @@ fn roundtrip_window() { } } + #[derive(Debug, Clone)] struct SimpleWindowUDF { signature: Signature, } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 27ac5d122f83..9ee8d0d51d96 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -73,8 +73,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; @@ -374,18 +374,17 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } } - let rt_func: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64))); + let return_type = DataType::Int64; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); - let st_func: StateTypeFunction = - Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64]))); + let state_type = vec![DataType::Int64]; - let udaf = AggregateUDF::new( + let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "example", - &Signature::exact(vec![DataType::Int64], Volatility::Immutable), - &rt_func, - &accumulator, - &st_func, - ); + Signature::exact(vec![DataType::Int64], Volatility::Immutable), + return_type, + accumulator, + state_type, + )); let ctx = SessionContext::new(); ctx.register_udaf(udaf.clone()); @@ -904,17 +903,35 @@ fn roundtrip_sym_hash_join() -> Result<()> { StreamJoinPartitionMode::Partitioned, StreamJoinPartitionMode::SinglePartition, ] { - roundtrip_test(Arc::new( - datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(schema_left.clone())), - Arc::new(EmptyExec::new(schema_right.clone())), - on.clone(), + for left_order in &[ + None, + Some(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("col", schema_left.index_of("col")?)), + options: Default::default(), + }]), + ] { + for right_order in &[ None, - join_type, - false, - *partition_mode, - )?, - ))?; + Some(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("col", schema_right.index_of("col")?)), + options: Default::default(), + }]), + ] { + roundtrip_test(Arc::new( + datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + false, + left_order.clone(), + right_order.clone(), + 
*partition_mode, + )?, + ))?; + } + } } } Ok(()) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 27351e10eb34..9fded63af3fc 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); + let expr = Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, Box::new(right), )); + eval_stack.push(expr); } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index c5c30e3a2253..a04df5589b85 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -250,7 +250,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Default expressions are restricted, column references are not allowed let empty_schema = DFSchema::empty(); let error_desc = |e: DataFusionError| match e { - DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }) => { + DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => { plan_datafusion_err!( "Column reference is not allowed in the DEFAULT expression : {}", e diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b96553ffbf86..b9fb4c65dc2c 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -31,9 +31,10 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, - Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - Result, ScalarValue, SchemaReference, TableReference, ToDFSchema, + not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, + Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, + OwnedTableReference, Result, ScalarValue, SchemaError, SchemaReference, + TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -1138,11 +1139,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .index_of_column_by_name(None, &c)? 
.ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; if value_indices[column_index].is_some() { - return Err(DataFusionError::SchemaError( - datafusion_common::SchemaError::DuplicateUnqualifiedField { - name: c, - }, - )); + return schema_err!(SchemaError::DuplicateUnqualifiedField { + name: c, + }); } else { value_indices[column_index] = Some(i); } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 48ba50145308..4de08a7124cf 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -756,9 +756,11 @@ fn join_with_ambiguous_column() { #[test] fn where_selection_with_ambiguous_column() { let sql = "SELECT * FROM person a, person b WHERE id = id + 1"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", + "\"Schema error: Ambiguous reference to unqualified field id\"", format!("{err:?}") ); } diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index e333dc816f66..7085e1ada09a 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -36,6 +36,7 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { version = "1.4.0", optional = true } chrono = { workspace = true, optional = true } +clap = { version = "4.4.8", features = ["derive", "env"] } datafusion = { path = "../core", version = "34.0.0" } datafusion-common = { workspace = true } futures = { version = "0.3.28" } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index aeb1cc4ec919..ffae144eae84 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -21,6 +21,7 @@ use std::path::{Path, PathBuf}; #[cfg(target_family = "windows")] use std::thread; +use clap::Parser; use datafusion_sqllogictest::{DataFusion, TestContext}; use futures::stream::StreamExt; use log::info; @@ -77,7 +78,8 @@ async fn run_tests() -> Result<()> { // Enable logging (e.g. 
set RUST_LOG=debug to see debug logs) env_logger::init(); - let options = Options::new(); + let options: Options = clap::Parser::parse(); + options.warn_on_ignored(); // Run all tests in parallel, reporting failures at the end // @@ -88,7 +90,7 @@ async fn run_tests() -> Result<()> { .map(|test_file| { tokio::task::spawn(async move { println!("Running {:?}", test_file.relative_path); - if options.complete_mode { + if options.complete { run_complete_file(test_file).await?; } else if options.postgres_runner { run_test_file_with_postgres(test_file).await?; @@ -289,49 +291,54 @@ fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { } /// Parsed command line options +/// +/// This structure attempts to mimic the command line options +/// accepted by IDEs such as CLion that pass arguments +/// +/// See for more details +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about= None)] struct Options { - // regex like - /// arguments passed to the program which are treated as - /// cargo test filter (substring match on filenames) - filters: Vec, - - /// Auto complete mode to fill out expected results - complete_mode: bool, - - /// Run Postgres compatibility tests with Postgres runner + #[clap(long, help = "Auto complete mode to fill out expected results")] + complete: bool, + + #[clap( + long, + env = "PG_COMPAT", + help = "Run Postgres compatibility tests with Postgres runner" + )] postgres_runner: bool, - /// Include tpch files + #[clap(long, env = "INCLUDE_TPCH", help = "Include tpch files")] include_tpch: bool, -} -impl Options { - fn new() -> Self { - let args: Vec<_> = std::env::args().collect(); - - let complete_mode = args.iter().any(|a| a == "--complete"); - let postgres_runner = std::env::var("PG_COMPAT").map_or(false, |_| true); - let include_tpch = std::env::var("INCLUDE_TPCH").map_or(false, |_| true); - - // treat args after the first as filters to run (substring matching) - let filters = if !args.is_empty() { - args.into_iter() - .skip(1) - // ignore command line arguments like `--complete` - .filter(|arg| !arg.as_str().starts_with("--")) - .collect::>() - } else { - vec![] - }; + #[clap( + action, + help = "regex like arguments passed to the program which are treated as cargo test filter (substring match on filenames)" + )] + filters: Vec, - Self { - filters, - complete_mode, - postgres_runner, - include_tpch, - } - } + #[clap( + long, + help = "IGNORED (for compatibility with built in rust test runner)" + )] + format: Option, + + #[clap( + short = 'Z', + long, + help = "IGNORED (for compatibility with built in rust test runner)" + )] + z_options: Option, + + #[clap( + long, + help = "IGNORED (for compatibility with built in rust test runner)" + )] + show_output: bool, +} +impl Options { /// Because this test can be run as a cargo test, commands like /// /// ```shell @@ -359,4 +366,19 @@ impl Options { let file_name = path.file_name().unwrap().to_str().unwrap().to_string(); !self.postgres_runner || file_name.starts_with(PG_COMPAT_FILE_PREFIX) } + + /// Logs warning messages to stdout if any ignored options are passed + fn warn_on_ignored(&self) { + if self.format.is_some() { + println!("WARNING: Ignoring `--format` compatibility option"); + } + + if self.z_options.is_some() { + println!("WARNING: Ignoring `-Z` compatibility option"); + } + + if self.show_output { + println!("WARNING: Ignoring `--show-output` compatibility option"); + } + } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt 
index 78575c9dffc5..aa512f6e2600 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2469,11 +2469,11 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); query T select arrow_typeof(x_dict) from value_dict group by x_dict; ---- -Int32 -Int32 -Int32 -Int32 -Int32 +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) statement ok drop table value diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 7cee615a5729..6b45f204fefc 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -953,7 +953,7 @@ select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array ---- [1, 2, 3, 4, 5] -# array_extract scalar function #8 (function alias `array_slice`) +# array_extract scalar function #8 (function alias `array_element`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- @@ -964,7 +964,7 @@ select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- 2 l -# list_element scalar function #9 (function alias `array_slice`) +# list_element scalar function #9 (function alias `array_element`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- @@ -975,7 +975,7 @@ select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2 ---- 2 l -# list_extract scalar function #10 (function alias `array_slice`) +# list_extract scalar function #10 (function alias `array_element`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- @@ -3238,30 +3238,55 @@ select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_join scalar function #5 (function alias `array_to_string`) query TTT select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select array_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # list_join scalar function #6 (function alias `list_join`) query TTT select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #1 query TTT select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); ---- h,l,o 1-3-5 2|3 +query TTT +select array_to_string(arrow_cast(['h', 
'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #2 query TTT select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); ---- h,-,-,-,o nil-2-nil-4-5 1|0|3 +query TTT +select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'LargeList(Utf8)'), ',', '-'), array_to_string(arrow_cast(make_array(NULL, 2, NULL, 4, 5), 'LargeList(Int64)'), '-', 'nil'), array_to_string(arrow_cast(make_array(1.0, NULL, 3.0), 'LargeList(Float64)'), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + # array_to_string with columns #1 # For reference @@ -3288,6 +3313,18 @@ NULL 51^52^54^55^56^57^58^59^60 NULL +query T +select array_to_string(column1, column4) from large_arrays_values; +---- +2,3,4,5,6,7,8,9,10 +11.12.13.14.15.16.17.18.20 +21-22-23-25-26-27-28-29-30 +31ok32ok33ok34ok35ok37ok38ok39ok40 +NULL +41$42$43$44$45$46$47$48$49$50 +51^52^54^55^56^57^58^59^60 +NULL + query TT select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from arrays_values; ---- @@ -3300,6 +3337,18 @@ NULL 1/2/3 51_52_54_55_56_57_58_59_60 1/2/3 61_62_63_64_65_66_67_68_69_70 1/2/3 +query TT +select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from large_arrays_values; +---- +2_3_4_5_6_7_8_9_10 1/2/3 +11_12_13_14_15_16_17_18_20 1/2/3 +21_22_23_25_26_27_28_29_30 1/2/3 +31_32_33_34_35_37_38_39_40 1/2/3 +NULL 1/2/3 +41_42_43_44_45_46_47_48_49_50 1/2/3 +51_52_54_55_56_57_58_59_60 1/2/3 +61_62_63_64_65_66_67_68_69_70 1/2/3 + query TT select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from arrays_values; ---- @@ -3312,6 +3361,18 @@ NULL 1.2.3 51_52_*_54_55_56_57_58_59_60 1.2.3 61_62_63_64_65_66_67_68_69_70 1.2.3 +query TT +select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from large_arrays_values; +---- +*_2_3_4_5_6_7_8_9_10 1.2.3 +11_12_13_14_15_16_17_18_*_20 1.2.3 +21_22_23_*_25_26_27_28_29_30 1.2.3 +31_32_33_34_35_*_37_38_39_40 1.2.3 +NULL 1.2.3 +41_42_43_44_45_46_47_48_49_50 1.2.3 +51_52_*_54_55_56_57_58_59_60 1.2.3 +61_62_63_64_65_66_67_68_69_70 1.2.3 + ## cardinality # cardinality scalar function @@ -3320,18 +3381,33 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali ---- 5 3 5 +query III +select cardinality(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), cardinality(arrow_cast([1, 3, 5], 'LargeList(Int64)')), cardinality(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +5 3 5 + # cardinality scalar function #2 query II select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); ---- 6 18 +query I +select cardinality(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +6 + # cardinality scalar function #3 query II select cardinality(make_array()), cardinality(make_array(make_array())) ---- NULL 0 +query II +select cardinality(arrow_cast(make_array(), 'LargeList(Null)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL 0 + # cardinality with columns query III select cardinality(column1), cardinality(column2), cardinality(column3) from arrays; 
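
The new `LargeList` cases in this test file all follow the same shape: take an existing `List` test, wrap the input in `arrow_cast(..., 'LargeList(...)')`, and expect output identical to the `List` variant. A quick way to spot-check one of them outside the sqllogictest harness is to run the same SQL through `SessionContext`; the sketch below is only a convenience harness, assuming the `datafusion` crate from this repository and `tokio` are available as dependencies.

```rust
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Same query as the .slt case above: cardinality over a LargeList produced by arrow_cast.
    let df = ctx
        .sql("select cardinality(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'))")
        .await?;
    let batches = df.collect().await?;

    // Expect a single row containing 5, matching the List variant of the test.
    println!("{}", pretty_format_batches(&batches)?);
    Ok(())
}
```
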
@@ -3344,6 +3420,17 @@ NULL 3 4 4 NULL 1 4 3 NULL +query III +select cardinality(column1), cardinality(column2), cardinality(column3) from large_arrays; +---- +4 3 5 +4 3 5 +4 3 5 +4 3 3 +NULL 3 4 +4 NULL 1 +4 3 NULL + ## array_remove (aliases: `list_remove`) # array_remove scalar function #1 @@ -4530,6 +4617,45 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array concatenate operator with scalars #4 (mixed) +query ? +select 0 || [1,2,3] || 4 || [5] || [6,7]; +---- +[0, 1, 2, 3, 4, 5, 6, 7] + +# array concatenate operator with nd-list #5 (mixed) +query ? +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10]; +---- +[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]] + +# array concatenate operator non-valid cases +## concat 2D with scalar is not valid +query error +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; + +## concat scalar with 2D is not valid +query error +select 0 || [[1,2,3]]; + +# array concatenate operator with column + +statement ok +CREATE TABLE array_concat_operator_table +AS VALUES + (0, [1, 2, 2, 3], 4, [5, 6, 5]), + (-1, [4, 5, 6], 7, [8, 1, 1]) +; + +query ? +select column1 || column2 || column3 || column4 from array_concat_operator_table; +---- +[0, 1, 2, 2, 3, 4, 5, 6, 5] +[-1, 4, 5, 6, 7, 8, 1, 1] + +statement ok +drop table array_concat_operator_table; + ## array containment operator # array containment operator with scalars #1 (at arrow) @@ -4768,6 +4894,90 @@ select string_to_list(e, 'm') from values; [adipiscing] NULL +# array_resize scalar function #1 +query ? +select array_resize(make_array(1, 2, 3), 1); +---- +[1] + +query ? +select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1); +---- +[1] + +# array_resize scalar function #2 +query ? +select array_resize(make_array(1, 2, 3), 5); +---- +[1, 2, 3, , ] + +query ? +select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 5); +---- +[1, 2, 3, , ] + +# array_resize scalar function #3 +query ? +select array_resize(make_array(1, 2, 3), 5, 4); +---- +[1, 2, 3, 4, 4] + +query ? +select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 5, 4); +---- +[1, 2, 3, 4, 4] + +# array_resize scalar function #4 +query error +select array_resize(make_array(1, 2, 3), -5, 2); + +# array_resize scalar function #5 +query ? +select array_resize(make_array(1.1, 2.2, 3.3), 10, 9.9); +---- +[1.1, 2.2, 3.3, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] + +query ? +select array_resize(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), 10, 9.9); +---- +[1.1, 2.2, 3.3, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] + +# array_resize scalar function #5 +query ? +select array_resize(column1, column2, column3) from arrays_values; +---- +[] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 2, 2] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5] +[] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] + +query ? 
+select array_resize(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from arrays_values; +---- +[] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 2, 2] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5] +[] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] + +# array_resize scalar function #5 +query ? +select array_resize([[1], [2], [3]], 10, [5]); +---- +[[1], [2], [3], [5], [5], [5], [5], [5], [5], [5]] + +query ? +select array_resize(arrow_cast([[1], [2], [3]], 'LargeList(List(Int64))'), 10, [5]); +---- +[[1], [2], [3], [5], [5], [5], [5], [5], [5], [5]] + ### Delete tables statement ok diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 3fad4d0f61b9..6a623e6c92f9 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -375,4 +375,4 @@ select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); query T select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); ---- -LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt new file mode 100644 index 000000000000..002aade2528e --- /dev/null +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for querying on dictionary encoded data + +# Note: These tables model data as is common for timeseries, such as in InfluxDB IOx +# There are three types of columns: +# 1. tag columns, which are string dictionaries, often with low cardinality +# 2. field columns, which are typed, +# 3. 
a `time` columns, which is a nanosecond timestamp + +# It is common to group and filter on the "tag" columns (and thus on dictionary +# encoded values) + +# Table m1 with a tag column `tag_id` 4 fields `f1` - `f4`, and `time` + +statement ok +CREATE VIEW m1 AS +SELECT + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as tag_id, + arrow_cast(column2, 'Float64') as f1, + arrow_cast(column3, 'Utf8') as f2, + arrow_cast(column4, 'Utf8') as f3, + arrow_cast(column5, 'Float64') as f4, + arrow_cast(column6, 'Timestamp(Nanosecond, None)') as time +FROM ( + VALUES + -- equivalent to the following line protocol data + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=1.0 1703030400000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=2.0 1703031000000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=3.0 1703031600000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=4.0 1703032200000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=5.0 1703032800000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=6.0 1703033400000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=7.0 1703034000000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=8.0 1703034600000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=9.0 1703035200000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=10.0 1703035800000000000 + ('1000', 32, 'foo', 'True', 1.0, 1703030400000000000), + ('1000', 32, 'foo', 'True', 2.0, 1703031000000000000), + ('1000', 32, 'foo', 'True', 3.0, 1703031600000000000), + ('1000', 32, 'foo', 'True', 4.0, 1703032200000000000), + ('1000', 32, 'foo', 'True', 5.0, 1703032800000000000), + ('1000', 32, 'foo', 'True', 6.0, 1703033400000000000), + ('1000', 32, 'foo', 'True', 7.0, 1703034000000000000), + ('1000', 32, 'foo', 'True', 8.0, 1703034600000000000), + ('1000', 32, 'foo', 'True', 9.0, 1703035200000000000), + ('1000', 32, 'foo', 'True', 10.0, 1703035800000000000) +); + +query ?RTTRP +SELECT * FROM m1; +---- +1000 32 foo True 1 2023-12-20T00:00:00 +1000 32 foo True 2 2023-12-20T00:10:00 +1000 32 foo True 3 2023-12-20T00:20:00 +1000 32 foo True 4 2023-12-20T00:30:00 +1000 32 foo True 5 2023-12-20T00:40:00 +1000 32 foo True 6 2023-12-20T00:50:00 +1000 32 foo True 7 2023-12-20T01:00:00 +1000 32 foo True 8 2023-12-20T01:10:00 +1000 32 foo True 9 2023-12-20T01:20:00 +1000 32 foo True 10 2023-12-20T01:30:00 + +# Note that te type of the tag column is `Dictionary(Int32, Utf8)` +query TTT +DESCRIBE m1; +---- +tag_id Dictionary(Int32, Utf8) YES +f1 Float64 YES +f2 Utf8 YES +f3 Utf8 YES +f4 Float64 YES +time Timestamp(Nanosecond, None) YES + + +# Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and `time` +statement ok +CREATE VIEW m2 AS +SELECT + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as type, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as tag_id, + arrow_cast(column3, 'Float64') as f5, + arrow_cast(column4, 'Timestamp(Nanosecond, None)') as time +FROM ( + VALUES + -- equivalent to the following line protocol data + -- m2,type=active,tag_id=1000 f5=100 1701648000000000000 + -- m2,type=active,tag_id=1000 f5=200 1701648600000000000 + -- m2,type=active,tag_id=1000 f5=300 1701649200000000000 + -- m2,type=active,tag_id=1000 f5=400 1701649800000000000 + -- m2,type=active,tag_id=1000 f5=500 1701650400000000000 + -- m2,type=active,tag_id=1000 f5=600 1701651000000000000 + -- m2,type=passive,tag_id=2000 f5=700 1701651600000000000 + -- m2,type=passive,tag_id=1000 f5=800 1701652200000000000 + -- m2,type=passive,tag_id=1000 f5=900 1701652800000000000 + -- 
m2,type=passive,tag_id=1000 f5=1000 1701653400000000000 + ('active', '1000', 100, 1701648000000000000), + ('active', '1000', 200, 1701648600000000000), + ('active', '1000', 300, 1701649200000000000), + ('active', '1000', 400, 1701649800000000000), + ('active', '1000', 500, 1701650400000000000), + ('active', '1000', 600, 1701651000000000000), + ('passive', '1000', 700, 1701651600000000000), + ('passive', '1000', 800, 1701652200000000000), + ('passive', '1000', 900, 1701652800000000000), + ('passive', '1000', 1000, 1701653400000000000) +); + +query ??RP +SELECT * FROM m2; +---- +active 1000 100 2023-12-04T00:00:00 +active 1000 200 2023-12-04T00:10:00 +active 1000 300 2023-12-04T00:20:00 +active 1000 400 2023-12-04T00:30:00 +active 1000 500 2023-12-04T00:40:00 +active 1000 600 2023-12-04T00:50:00 +passive 1000 700 2023-12-04T01:00:00 +passive 1000 800 2023-12-04T01:10:00 +passive 1000 900 2023-12-04T01:20:00 +passive 1000 1000 2023-12-04T01:30:00 + +query TTT +DESCRIBE m2; +---- +type Dictionary(Int32, Utf8) YES +tag_id Dictionary(Int32, Utf8) YES +f5 Float64 YES +time Timestamp(Nanosecond, None) YES + +query I +select count(*) from m1 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00'; +---- +10 + +query RRR rowsort +select min(f5), max(f5), avg(f5) from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00' group by type; +---- +100 600 350 +700 1000 850 + +query IRRRP +select count(*), min(f5), max(f5), avg(f5), date_bin('30 minutes', time) as "time" +from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00' +group by date_bin('30 minutes', time) +order by date_bin('30 minutes', time) DESC +---- +1 1000 1000 1000 2023-12-04T01:30:00 +3 700 900 800 2023-12-04T01:00:00 +3 400 600 500 2023-12-04T00:30:00 +3 100 300 200 2023-12-04T00:00:00 + + + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738 +# This query should work correctly +query P?TT rowsort +SELECT + "data"."timestamp" as "time", + "data"."tag_id", + "data"."field", + "data"."value" +FROM ( + ( + SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value" + FROM "m2" + WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00' + AND "m2"."f5" IS NOT NULL + AND "m2"."type" IN ('active') + AND "m2"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f1" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo 
+2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo + + +# deterministic sort (so we can avoid rowsort) +query P?TT +SELECT + "data"."timestamp" as "time", + "data"."tag_id", + "data"."field", + "data"."value" +FROM ( + ( + SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value" + FROM "m2" + WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00' + AND "m2"."f5" IS NOT NULL + AND "m2"."type" IN ('active') + AND "m2"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f1" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id", + "data"."field", + "data"."value" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo +2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 4583ef319b7f..2a39e3138869 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -180,6 +180,7 @@ initial_logical_plan Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after operator_to_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt new file mode 100644 index 000000000000..a2a8d9c6475c --- /dev/null +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -0,0 +1,1251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# test_boolean_expressions +query BBBB +SELECT true, false, false = false, true = false +---- +true false true false + +# test_mathematical_expressions_with_null +query RRRRRRRRRRRRRRRRRR?RRRRRRRIRRRRRRBB +SELECT + sqrt(NULL), + cbrt(NULL), + sin(NULL), + cos(NULL), + tan(NULL), + asin(NULL), + acos(NULL), + atan(NULL), + sinh(NULL), + cosh(NULL), + tanh(NULL), + asinh(NULL), + acosh(NULL), + atanh(NULL), + floor(NULL), + ceil(NULL), + round(NULL), + trunc(NULL), + abs(NULL), + signum(NULL), + exp(NULL), + ln(NULL), + log2(NULL), + log10(NULL), + power(NULL, 2), + power(NULL, NULL), + power(2, NULL), + atan2(NULL, NULL), + atan2(1, NULL), + atan2(NULL, 1), + nanvl(NULL, NULL), + nanvl(1, NULL), + nanvl(NULL, 1), + isnan(NULL), + iszero(NULL) +---- +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# test_array_cast_invalid_timezone_will_panic +statement error Parser error: Invalid timezone "Foo": 'Foo' is not a valid timezone +SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some("Foo"))') + +# test_array_index +query III??IIIIII +SELECT + ([5,4,3,2,1])[1], + ([5,4,3,2,1])[2], + ([5,4,3,2,1])[5], + ([[1, 2], [2, 3], [3,4]])[1], + ([[1, 2], [2, 3], [3,4]])[3], + ([[1, 2], [2, 3], [3,4]])[1][1], + ([[1, 2], [2, 3], [3,4]])[2][2], + ([[1, 2], [2, 3], [3,4]])[3][2], + -- out of bounds + ([5,4,3,2,1])[0], + ([5,4,3,2,1])[6], + -- ([5,4,3,2,1])[-1], -- TODO: wrong answer + -- ([5,4,3,2,1])[null], -- TODO: not supported + ([5,4,3,2,1])[100] +---- +5 4 1 [1, 2] [3, 4] 1 3 4 NULL NULL NULL + +# test_array_literals +query ????? +SELECT + [1,2,3,4,5], + [true, false], + ['str1', 'str2'], + [[1,2], [3,4]], + [] +---- +[1, 2, 3, 4, 5] [true, false] [str1, str2] [[1, 2], [3, 4]] [] + +# test_struct_literals +query ?????? +SELECT + STRUCT(1,2,3,4,5), + STRUCT(Null), + STRUCT(2), + STRUCT('1',Null), + STRUCT(true, false), + STRUCT('str1', 'str2') +---- +{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5} {c0: } {c0: 2} {c0: 1, c1: } {c0: true, c1: false} {c0: str1, c1: str2} + +# test binary_bitwise_shift +query IIII +SELECT + 2 << 10, + 2048 >> 10, + 2048 << NULL, + 2048 >> NULL +---- +2048 2 NULL NULL + +query ? +SELECT interval '1' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '1 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '500 milliseconds' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs + +query ? +SELECT interval '5 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +SELECT interval '0.5 minute' +---- +0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs + +query ? +SELECT interval '.5 minute' +---- +0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs + +query ? +SELECT interval '5 minute' +---- +0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs + +query ? +SELECT interval '5 minute 1 second' +---- +0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs + +query ? +SELECT interval '1 hour' +---- +0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 hour' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 day' +---- +0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs + +query ? 
+SELECT interval '1 week' +---- +0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 weeks' +---- +0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 day 1' +---- +0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '0.5' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs + +query ? +SELECT interval '0.5 day 1' +---- +0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '0.49 day' +---- +0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs + +query ? +SELECT interval '0.499 day' +---- +0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs + +query ? +SELECT interval '0.4999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs + +query ? +SELECT interval '0.49999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs + +query ? +SELECT interval '0.49999999999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs + +query ? +SELECT interval '5 day' +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +# Hour is ignored, this matches PostgreSQL +query ? +SELECT interval '5 day' hour +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds' +---- +0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs + +query ? +SELECT interval '0.5 month' +---- +0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '0.5' month +---- +0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 month' +---- +0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1' MONTH +---- +0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 month' +---- +0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '13 month' +---- +0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '0.5 year' +---- +0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year' +---- +0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 decade' +---- +0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 decades' +---- +0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 century' +---- +0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 year' +---- +0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day' +---- +0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour' +---- +0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour 1 minute' +---- +0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour 1 minute 1 second' +---- +0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs + +query I +SELECT ascii('') +---- +0 + +query I +SELECT ascii('x') +---- +120 + +query I +SELECT ascii(NULL) +---- +NULL + +query I +SELECT bit_length('') +---- +0 + +query I +SELECT bit_length('chars') +---- +40 + +query I +SELECT bit_length('josé') +---- +40 + +query ? 
+SELECT bit_length(NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ', NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ') +---- +xyxtrimyyx + +query T +SELECT btrim('\n xyxtrimyyx \n') +---- +\n xyxtrimyyx \n + +query T +SELECT btrim('xyxtrimyyx', 'xyz') +---- +trim + +query T +SELECT btrim('\nxyxtrimyyx\n', 'xyz\n') +---- +trim + +query ? +SELECT btrim(NULL, 'xyz') +---- +NULL + +query T +SELECT chr(CAST(120 AS int)) +---- +x + +query T +SELECT chr(CAST(128175 AS int)) +---- +💯 + +query T +SELECT chr(CAST(NULL AS int)) +---- +NULL + +query T +SELECT concat('a','b','c') +---- +abc + +query T +SELECT concat('abcde', 2, NULL, 22) +---- +abcde222 + +query T +SELECT concat(NULL) +---- +(empty) + +query T +SELECT concat_ws(',', 'abcde', 2, NULL, 22) +---- +abcde,2,22 + +query T +SELECT concat_ws('|','a','b','c') +---- +a|b|c + +query T +SELECT concat_ws('|',NULL) +---- +(empty) + +query T +SELECT concat_ws(NULL,'a',NULL,'b','c') +---- +NULL + +query T +SELECT concat_ws('|','a',NULL) +---- +a + +query T +SELECT concat_ws('|','a',NULL,NULL) +---- +a + +query T +SELECT initcap('') +---- +(empty) + +query T +SELECT initcap('hi THOMAS') +---- +Hi Thomas + +query ? +SELECT initcap(NULL) +---- +NULL + +query T +SELECT lower('') +---- +(empty) + +query T +SELECT lower('TOM') +---- +tom + +query ? +SELECT lower(NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ') +---- +zzzytest + +query T +SELECT ltrim('zzzytest', 'xyz') +---- +test + +query ? +SELECT ltrim(NULL, 'xyz') +---- +NULL + +query I +SELECT octet_length('') +---- +0 + +query I +SELECT octet_length('chars') +---- +5 + +query I +SELECT octet_length('josé') +---- +5 + +query ? +SELECT octet_length(NULL) +---- +NULL + +query T +SELECT repeat('Pg', 4) +---- +PgPgPgPg + +query T +SELECT repeat('Pg', CAST(NULL AS INT)) +---- +NULL + +query ? +SELECT repeat(NULL, 4) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'cd', 'XX') +---- +abXXefabXXef + +query T +SELECT replace('abcdefabcdef', 'cd', NULL) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'notmatch', 'XX') +---- +abcdefabcdef + +query T +SELECT replace('abcdefabcdef', NULL, 'XX') +---- +NULL + +query ? +SELECT replace(NULL, 'cd', 'XX') +---- +NULL + +query T +SELECT rtrim(' testxxzx ') +---- + testxxzx + +query T +SELECT rtrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT rtrim('testxxzx', 'xyz') +---- +test + +query ? +SELECT rtrim(NULL, 'xyz') +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 2) +---- +def + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 20) +---- +(empty) + +query ? 
+SELECT split_part(NULL, '~@~', 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', NULL, 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT)) +---- +NULL + +query B +SELECT starts_with('alphabet', 'alph') +---- +true + +query B +SELECT starts_with('alphabet', 'blph') +---- +false + +query B +SELECT starts_with(NULL, 'blph') +---- +NULL + +query B +SELECT starts_with('alphabet', NULL) +---- +NULL + +query T +SELECT to_hex(2147483647) +---- +7fffffff + +query T +SELECT to_hex(9223372036854775807) +---- +7fffffffffffffff + +query T +SELECT to_hex(CAST(NULL AS int)) +---- +NULL + +query T +SELECT trim(' tom ') +---- +tom + +query T +SELECT trim(LEADING ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' tom ') +---- +tom + +query T +SELECT trim(LEADING ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' ' FROM ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(LEADING 'x' FROM 'xxxtomxxx') +---- +tomxxx + +query T +SELECT trim(TRAILING 'x' FROM 'xxxtomxxx') +---- +xxxtom + +query T +SELECT trim(BOTH 'x' FROM 'xxxtomxx') +---- +tom + +query T +SELECT trim('x' FROM 'xxxtomxx') +---- +tom + + +query T +SELECT trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdefxyx + +query T +SELECT trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx') +---- +xyxabcxyzdef + +query T +SELECT trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim('xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim(' tom') +---- +tom + +query T +SELECT trim('') +---- +(empty) + +query T +SELECT trim('tom ') +---- +tom + +query T +SELECT upper('') +---- +(empty) + +query T +SELECT upper('tom') +---- +TOM + +query ? +SELECT upper(NULL) +---- +NULL + +# TODO issue: https://github.com/apache/arrow-datafusion/issues/6596 +# query ?? 
+#SELECT +# CAST([1,2,3,4] AS INT[]) as a, +# CAST([1,2,3,4] AS NUMERIC(10,4)[]) as b +#---- +#[1, 2, 3, 4] [1.0000, 2.0000, 3.0000, 4.0000] + +# test_random_expression +query BB +SELECT + random() BETWEEN 0.0 AND 1.0, + random() = random() +---- +true false + +# test_uuid_expression +query II +SELECT octet_length(uuid()), length(uuid()) +---- +36 36 + +# test_cast_expressions +query IIII +SELECT + CAST('0' AS INT) as a, + CAST(NULL AS INT) as b, + TRY_CAST('0' AS INT) as c, + TRY_CAST('x' AS INT) as d +---- +0 NULL 0 NULL + +# test_extract_date_part + +query R +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query R +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query R +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query R +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query R +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query R +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query R +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query R +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query R +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query R +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query R +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query R +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# Keep precision when coercing Utf8 to Timestamp +query R +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +query R +SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# test_extract_epoch + +query R +SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from 
'2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query B +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) +---- +true + +query B +SELECT 'a' IN ('a','b') +---- +true + +query B +SELECT 'c' IN ('a','b') +---- +false + +query B +SELECT 'c' NOT IN ('a','b') +---- +true + +query B +SELECT 'a' NOT IN ('a','b') +---- +false + +query B +SELECT NULL IN ('a','b') +---- +NULL + +query B +SELECT NULL NOT IN ('a','b') +---- +NULL + +query B +SELECT 'a' IN ('a','b',NULL) +---- +true + +query B +SELECT 'c' IN ('a','b',NULL) +---- +NULL + +query B +SELECT 'a' NOT IN ('a','b',NULL) +---- +false + +query B +SELECT 'c' NOT IN ('a','b',NULL) +---- +NULL + +query B +SELECT 0 IN (0,1,2) +---- +true + +query B +SELECT 3 IN (0,1,2) +---- +false + +query B +SELECT 3 NOT IN (0,1,2) +---- +true + +query B +SELECT 0 NOT IN (0,1,2) +---- +false + +query B +SELECT NULL IN (0,1,2) +---- +NULL + +query B +SELECT NULL NOT IN (0,1,2) +---- +NULL + +query B +SELECT 0 IN (0,1,2,NULL) +---- +true + +query B +SELECT 3 IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0 NOT IN (0,1,2,NULL) +---- +false + +query B +SELECT 3 NOT IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2) +---- +false + +query B +SELECT NULL IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT NULL NOT IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2,NULL) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2,NULL) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT '1' IN ('a','b',1) 
+---- +true + +query B +SELECT '2' IN ('a','b',1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',1) +---- +true + +query B +SELECT '1' NOT IN ('a','b',1) +---- +false + +query B +SELECT NULL IN ('a','b',1) +---- +NULL + +query B +SELECT NULL NOT IN ('a','b',1) +---- +NULL + +query B +SELECT '1' IN ('a','b',NULL,1) +---- +true + +query B +SELECT '2' IN ('a','b',NULL,1) +---- +NULL + +query B +SELECT '1' NOT IN ('a','b',NULL,1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',NULL,1) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/group_by.slt similarity index 92% rename from datafusion/sqllogictest/test_files/groupby.slt rename to datafusion/sqllogictest/test_files/group_by.slt index b09ff79e88d5..7c5803d38594 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4284,3 +4284,433 @@ LIMIT 5 1 FRA 3 2022-01-02T12:00:00 EUR 200 1 TUR 2 2022-01-01T11:30:00 TRY 75 1 TUR 4 2022-01-03T10:00:00 TRY 100 + +# Create a table with timestamp data +statement ok +CREATE TABLE src_table ( + t1 TIMESTAMP, + c2 INT, +) AS VALUES +('2020-12-10T00:00:00.00Z', 0), +('2020-12-11T00:00:00.00Z', 1), +('2020-12-12T00:00:00.00Z', 2), +('2020-12-13T00:00:00.00Z', 3), +('2020-12-14T00:00:00.00Z', 4), +('2020-12-15T00:00:00.00Z', 5), +('2020-12-16T00:00:00.00Z', 6), +('2020-12-17T00:00:00.00Z', 7), +('2020-12-18T00:00:00.00Z', 8), +('2020-12-19T00:00:00.00Z', 9); + +# Use src_table to create a partitioned file +query PI +COPY (SELECT * FROM src_table) +TO 'test_files/scratch/group_by/timestamp_table/0.csv' +(FORMAT CSV, SINGLE_FILE_OUTPUT true); +---- +10 + +query PI +COPY (SELECT * FROM src_table) +TO 'test_files/scratch/group_by/timestamp_table/1.csv' +(FORMAT CSV, SINGLE_FILE_OUTPUT true); +---- +10 + +query PI +COPY (SELECT * FROM src_table) +TO 'test_files/scratch/group_by/timestamp_table/2.csv' +(FORMAT CSV, SINGLE_FILE_OUTPUT true); +---- +10 + +query PI +COPY (SELECT * FROM src_table) +TO 'test_files/scratch/group_by/timestamp_table/3.csv' +(FORMAT CSV, SINGLE_FILE_OUTPUT true); +---- +10 + +# Create a table from the generated CSV files: +statement ok +CREATE EXTERNAL TABLE timestamp_table ( + t1 TIMESTAMP, + c2 INT, +) +STORED AS CSV +WITH HEADER ROW +LOCATION 'test_files/scratch/group_by/timestamp_table'; + +# Group By using date_trunc +query PI rowsort +SELECT date_trunc('week', t1) as week, sum(c2) +FROM timestamp_table +GROUP BY date_trunc('week', t1) +---- +2020-12-07T00:00:00 24 +2020-12-14T00:00:00 156 + +# GROUP BY using LIMIT +query IP +SELECT c2, MAX(t1) +FROM timestamp_table +GROUP BY c2 +ORDER BY MAX(t1) DESC +LIMIT 4; +---- +9 2020-12-19T00:00:00 +8 2020-12-18T00:00:00 +7 2020-12-17T00:00:00 +6 2020-12-16T00:00:00 + +# Explain the GROUP BY with LIMIT to ensure the plan contains `lim=[4]` +query TT +EXPLAIN +SELECT c2, MAX(t1) +FROM timestamp_table +GROUP BY c2 +ORDER BY MAX(t1) DESC +LIMIT 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(timestamp_table.t1) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[timestamp_table.c2]], aggr=[[MAX(timestamp_table.t1)]] +------TableScan: timestamp_table projection=[t1, c2] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(timestamp_table.t1)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(timestamp_table.t1)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[MAX(timestamp_table.t1)], lim=[4] +--------CoalesceBatchesExec: target_batch_size=2 
+----------RepartitionExec: partitioning=Hash([c2@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[c2@1 as c2], aggr=[MAX(timestamp_table.t1)], lim=[4] +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 +----------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/0.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/1.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/2.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/3.csv]]}, projection=[t1, c2], has_header=true + +# Clean up +statement ok +DROP TABLE src_table; + +statement ok +DROP TABLE timestamp_table; + +### BEGIN Group By with Dictionary Variants ### +# +# The following tests use GROUP BY on tables with dictionary columns. +# The same test is repeated using dictionaries with the key types: +# +# - Int8 +# - Int16 +# - Int32 +# - Int64 +# - UInt8 +# - UInt16 +# - UInt32 +# - UInt64 + +# Table with an int column and Dict column: +statement ok +CREATE TABLE int8_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(Int8, Utf8)')), +(2, arrow_cast('B', 'Dictionary(Int8, Utf8)')), +(2, arrow_cast('A', 'Dictionary(Int8, Utf8)')), +(4, arrow_cast('A', 'Dictionary(Int8, Utf8)')), +(1, arrow_cast('C', 'Dictionary(Int8, Utf8)')), +(1, arrow_cast('A', 'Dictionary(Int8, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM int8_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM int8_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM int8_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE int8_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE int16_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(Int16, Utf8)')), +(2, arrow_cast('B', 'Dictionary(Int16, Utf8)')), +(2, arrow_cast('A', 'Dictionary(Int16, Utf8)')), +(4, arrow_cast('A', 'Dictionary(Int16, Utf8)')), +(1, arrow_cast('C', 'Dictionary(Int16, Utf8)')), +(1, arrow_cast('A', 'Dictionary(Int16, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM int16_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM int16_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM int16_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE int16_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE int32_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(Int32, Utf8)')), +(2, arrow_cast('B', 'Dictionary(Int32, Utf8)')), +(2, arrow_cast('A', 'Dictionary(Int32, Utf8)')), +(4, arrow_cast('A', 'Dictionary(Int32, Utf8)')), +(1, arrow_cast('C', 'Dictionary(Int32, Utf8)')), +(1, arrow_cast('A', 'Dictionary(Int32, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM int32_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) 
FROM int32_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM int32_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE int32_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE int64_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(Int64, Utf8)')), +(2, arrow_cast('B', 'Dictionary(Int64, Utf8)')), +(2, arrow_cast('A', 'Dictionary(Int64, Utf8)')), +(4, arrow_cast('A', 'Dictionary(Int64, Utf8)')), +(1, arrow_cast('C', 'Dictionary(Int64, Utf8)')), +(1, arrow_cast('A', 'Dictionary(Int64, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM int64_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM int64_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM int64_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE int64_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE uint8_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), +(2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')), +(2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), +(4, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), +(1, arrow_cast('C', 'Dictionary(UInt8, Utf8)')), +(1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM uint8_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM uint8_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM uint8_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE uint8_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE uint16_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), +(2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')), +(2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), +(4, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), +(1, arrow_cast('C', 'Dictionary(UInt16, Utf8)')), +(1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM uint16_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM uint16_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM uint16_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE uint16_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE uint32_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), +(2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')), +(2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), +(4, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), +(1, arrow_cast('C', 'Dictionary(UInt32, Utf8)')), +(1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM uint32_dict GROUP BY column2; +---- +A 4 
+B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM uint32_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM uint32_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE uint32_dict; + +# Table with an int column and Dict column: +statement ok +CREATE TABLE uint64_dict AS VALUES +(1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), +(2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')), +(2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), +(4, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), +(1, arrow_cast('C', 'Dictionary(UInt64, Utf8)')), +(1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')); + +# Group by the non-dict column +query ?I rowsort +SELECT column2, count(column1) FROM uint64_dict GROUP BY column2; +---- +A 4 +B 1 +C 1 + +# Group by the value column with dict as aggregate +query II rowsort +SELECT column1, count(column2) FROM uint64_dict GROUP BY column1; +---- +1 3 +2 2 +4 1 + +# Group by with dict as aggregate using distinct +query II rowsort +SELECT column1, count(distinct column2) FROM uint64_dict GROUP BY column1; +---- +1 2 +2 2 +4 1 + +# Clean up +statement ok +DROP TABLE uint64_dict; + +### END Group By with Dictionary Variants ### diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 1b5ad86546a3..f8893bf7ae5c 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -242,13 +242,13 @@ datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.max_row_group_size 1048576 Sets maximum number of rows in a row group +datafusion.execution.parquet.max_row_group_size 1048576 Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. 
datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file -datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded +datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index c9dd7ca604ad..ca9b918ff3ee 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -626,6 +626,38 @@ Alice 100 Alice 1 Alice 50 Alice 2 Alice 100 Alice 2 +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +# make sure when target partition is 1, hash repartition is not added +# to the final plan. 
+query TT +EXPLAIN SELECT * +FROM t1, +t1 as t2 +WHERE t1.a=t2.a; +---- +logical_plan +Inner Join: t1.a = t2.a +--TableScan: t1 projection=[a, b] +--SubqueryAlias: t2 +----TableScan: t1 projection=[a, b] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] +----MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] + +# Reset the configs to old values +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_joins = false; + statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt new file mode 100644 index 000000000000..9829299f43e5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for repartitioning +########## + +# Set 4 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +COPY (VALUES (1, 2), (2, 5), (3, 2), (4, 5), (5, 0)) TO 'test_files/scratch/repartition/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE parquet_table(column1 int, column2 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition/parquet_table/'; + +# enable round robin repartitioning +statement ok +set datafusion.optimizer.enable_round_robin_repartition = true; + +query TT +EXPLAIN SELECT column1, SUM(column2) FROM parquet_table GROUP BY column1; +---- +logical_plan +Aggregate: groupBy=[[parquet_table.column1]], aggr=[[SUM(CAST(parquet_table.column2 AS Int64))]] +--TableScan: parquet_table projection=[column1, column2] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=4 +------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2] + +# disable round robin repartitioning +statement ok +set datafusion.optimizer.enable_round_robin_repartition = false; + +query TT +EXPLAIN SELECT column1, SUM(column2) FROM parquet_table GROUP BY column1; +---- +logical_plan +Aggregate: groupBy=[[parquet_table.column1]], aggr=[[SUM(CAST(parquet_table.column2 AS Int64))]] +--TableScan: parquet_table projection=[column1, column2] 
+physical_plan +AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=1 +------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +--------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2] + + +# Cleanup +statement ok +DROP TABLE parquet_table; diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 02eccd7c5d06..73487635e9cb 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,27 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] + +# disable round robin repartitioning +statement ok +set datafusion.optimizer.enable_round_robin_repartition = false; + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) again +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] + +# enable round robin repartitioning again +statement ok +set datafusion.optimizer.enable_round_robin_repartition = true; # 
create a second parquet file statement ok @@ -82,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -118,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # Cleanup statement ok @@ -147,7 +167,7 @@ WITH HEADER ROW LOCATION 'test_files/scratch/repartition_scan/csv_table/'; query I -select * from csv_table; +select * from csv_table ORDER BY column1; ---- 1 2 @@ -190,7 +210,7 @@ STORED AS json LOCATION 'test_files/scratch/repartition_scan/json_table/'; query I -select * from 
"json_table"; +select * from "json_table" ORDER BY column1; ---- 1 2 diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index ea570b99d4dd..132bcdd246fe 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -114,6 +114,14 @@ VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5) ---- 1 2 3 4 5 6 7 8 9 10 11 12 13 NULL F 3.5 +# Test non-literal expressions in VALUES +query II +VALUES (1, CASE WHEN RANDOM() > 0.5 THEN 1 ELSE 1 END), + (2, CASE WHEN RANDOM() > 0.5 THEN 2 ELSE 2 END); +---- +1 1 +2 2 + query IT SELECT * FROM (VALUES (1,'a'),(2,NULL)) AS t(c1, c2) ---- diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index c84e46c965fa..7829ce53ac9a 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1862,7 +1862,7 @@ SELECT to_timestamp(null) is null as c1, ---- true true true true true true true true true true true true true -# verify timestamp output types +# verify timestamp output types query TTT SELECT arrow_typeof(to_timestamp(1)), arrow_typeof(to_timestamp(null)), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) ---- @@ -1880,7 +1880,7 @@ SELECT arrow_typeof(to_timestamp(1)) = arrow_typeof(1::timestamp) as c1, true true true true true true # known issues. currently overflows (expects default precision to be microsecond instead of nanoseconds. Work pending) -#verify extreme values +#verify extreme values #query PPPPPPPP #SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) #---- @@ -1894,3 +1894,56 @@ query B select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); ---- false + +########## +## Test query MAX Timestamp and MiN Timestamp +########## + +statement ok +create table table_a (val int, ts timestamp) as values (1, '2020-09-08T11:42:29.190'::timestamp), (2, '2000-02-01T00:00:00'::timestamp) + +query P +SELECT MIN(table_a.ts) FROM table_a; +---- +2000-02-01T00:00:00 + +query P +SELECT MAX(table_a.ts) FROM table_a; +---- +2020-09-08T11:42:29.190 + +statement ok +drop table table_a + +########## +## Test query MAX Timestamp and MiN Timestamp +########## + +statement ok +create table table_a (ts timestamp) as values + ('2020-09-08T11:42:29Z'::timestamp), + ('2020-09-08T12:42:29Z'::timestamp), + ('2020-09-08T13:42:29Z'::timestamp) + +statement ok +create table table_b (ts timestamp) as values + ('2020-09-08T11:42:29.190Z'::timestamp), + ('2020-09-08T13:42:29.190Z'::timestamp), + ('2020-09-08T12:42:29.190Z'::timestamp) + +query PPB +SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b order by table_a.ts desc, table_b.ts desc +---- +2020-09-08T13:42:29 2020-09-08T13:42:29.190 false +2020-09-08T13:42:29 2020-09-08T12:42:29.190 false +2020-09-08T13:42:29 2020-09-08T11:42:29.190 false +2020-09-08T12:42:29 2020-09-08T13:42:29.190 false +2020-09-08T12:42:29 2020-09-08T12:42:29.190 false +2020-09-08T12:42:29 2020-09-08T11:42:29.190 false +2020-09-08T11:42:29 2020-09-08T13:42:29.190 false +2020-09-08T11:42:29 2020-09-08T12:42:29.190 false +2020-09-08T11:42:29 2020-09-08T11:42:29.190 false + + + + diff --git a/datafusion/sqllogictest/test_files/tpch/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/q2.slt.part index ed439348d22d..ed950db190bb 
100644 --- a/datafusion/sqllogictest/test_files/tpch/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q2.slt.part @@ -238,7 +238,7 @@ order by p_partkey limit 10; ---- -9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily +9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily 9508.37 Supplier#000000070 FRANCE 3563 Manufacturer#1 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9508.37 Supplier#000000070 FRANCE 17268 Manufacturer#4 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9453.01 Supplier#000000802 ROMANIA 10021 Manufacturer#5 ,6HYXb4uaHITmtMBj4Ak57Pd 29-342-882-6463 gular frets. permanently special multipliers believe blithely alongs diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7d6d59201396..100c2143837a 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3794,7 +3794,7 @@ select a, 1 1 2 1 -# support scalar value in ORDER BY +# support scalar value in ORDER BY query I select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x ---- diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 0a9a6e8dd12b..160af37ef961 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,7 +35,9 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.21.0" +substrait = "0.22.1" + +[dev-dependencies] tokio = "1.17" [features] diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index c5f795d0653a..91af15a6ea62 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -28,7 +28,7 @@ authors = { workspace = true } rust-version = "1.70" [lib] -crate-type = ["cdylib", "rlib",] +crate-type = ["cdylib", "rlib"] [dependencies] @@ -37,11 +37,14 @@ crate-type = ["cdylib", "rlib",] # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for # code size when deploying. 
console_error_panic_hook = { version = "0.1.1", optional = true } +datafusion = { path = "../core", default-features = false } datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } # getrandom must be compiled with js feature diff --git a/datafusion/wasmtest/README.md b/datafusion/wasmtest/README.md index d26369a18ab9..4af0f94db9e9 100644 --- a/datafusion/wasmtest/README.md +++ b/datafusion/wasmtest/README.md @@ -59,10 +59,13 @@ Then open http://localhost:8080/ in a web browser and check the console to see t The following DataFusion crates are verified to work in a wasm-pack environment using the default `wasm32-unknown-unknown` target: +- `datafusion` (datafusion-core) with default-features disabled to remove `bzip2-sys` from `async-compression` - `datafusion-common` with default-features disabled to remove the `parquet` dependency (see below) - `datafusion-expr` +- `datafusion-execution` - `datafusion-optimizer` - `datafusion-physical-expr` +- `datafusion-physical-plan` - `datafusion-sql` The difficulty with getting the remaining DataFusion crates compiled to WASM is that they have non-optional dependencies on the [`parquet`](https://docs.rs/crate/parquet/) crate with its default features enabled. Several of the default parquet crate features require native dependencies that are not compatible with WASM, in particular the `lz4` and `zstd` features. If we can arrange our feature flags to make it possible to depend on parquet with these features disabled, then it should be possible to compile the core `datafusion` crate to WASM as well. 
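The wasmtest README above lists the crates verified to build for the `wasm32-unknown-unknown` target. As a rough illustration of the kind of browser-side entry point such a crate can expose, here is a minimal sketch (not taken from this PR) of a `wasm_bindgen` export that exercises only pure-Rust DataFusion dependencies; the function name and its use of datafusion-sql's re-exported `sqlparser` are assumptions made for illustration.

```rust
use datafusion_sql::sqlparser::dialect::GenericDialect;
use datafusion_sql::sqlparser::parser::Parser;
use wasm_bindgen::prelude::*;

/// Hypothetical wasm export: parse a SQL string entirely in the browser and
/// return a debug rendering of the parsed statements. Only pure-Rust
/// dependencies are involved, which is why datafusion-sql is on the list of
/// crates that compile cleanly to wasm32-unknown-unknown.
#[wasm_bindgen]
pub fn parse_sql(sql: &str) -> String {
    let dialect = GenericDialect {};
    match Parser::parse_sql(&dialect, sql) {
        Ok(statements) => format!("{statements:#?}"),
        Err(e) => format!("parse error: {e}"),
    }
}
```

Because a function like this avoids parquet and the native compression codecs, it can run entirely in the browser, which is exactly what the crate list above is meant to verify.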
diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index c7b90cf05f1b..5163c99bd5ac 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -20,8 +20,7 @@ }, "../pkg": { "name": "datafusion-wasmtest", - "version": "31.0.0", - "license": "Apache-2.0" + "version": "0.0.1" }, "node_modules/@discoveryjs/json-ext": { "version": "0.5.7", @@ -1667,9 +1666,9 @@ } }, "node_modules/follow-redirects": { - "version": "1.15.3", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", - "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "version": "1.15.4", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", + "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", "dev": true, "funding": [ { @@ -3324,7 +3323,7 @@ }, "node_modules/serve-index/node_modules/http-errors": { "version": "1.6.3", - "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", "dev": true, "dependencies": { @@ -5581,9 +5580,9 @@ } }, "follow-redirects": { - "version": "1.15.3", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", - "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "version": "1.15.4", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", + "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", "dev": true }, "forwarded": { @@ -6784,7 +6783,7 @@ }, "http-errors": { "version": "1.6.3", - "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", "dev": true, "requires": { diff --git a/docs/logos/DataFUSION-Logo-Dark.svg b/docs/logos/DataFUSION-Logo-Dark.svg new file mode 100644 index 000000000000..e16f244430e6 --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Dark.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Dark \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Dark@2x.png b/docs/logos/DataFUSION-Logo-Dark@2x.png new file mode 100644 index 000000000000..cc60f12a0e4f Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Dark@4x.png b/docs/logos/DataFUSION-Logo-Dark@4x.png new file mode 100644 index 000000000000..0503c216ac84 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@4x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light.svg b/docs/logos/DataFUSION-Logo-Light.svg new file mode 100644 index 000000000000..b3bef2193dde --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Light.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Light \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Light@2x.png b/docs/logos/DataFUSION-Logo-Light@2x.png new file mode 100644 index 000000000000..8992213b0e60 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light@4x.png b/docs/logos/DataFUSION-Logo-Light@4x.png new file mode 100644 index 
000000000000..bd329ca21956 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@4x.png differ diff --git a/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf new file mode 100644 index 000000000000..4594c50f9044 Binary files /dev/null and b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf differ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index a9d0f30bcf8e..9f7880049856 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -3,3 +3,24 @@ {# Silence the navbar #} {% block docs_navbar %} {% endblock %} + + +{% block footer %} + + + +{% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 3fa6c6091d6f..becece330d1a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -33,9 +33,9 @@ # -- Project information ----------------------------------------------------- -project = 'Arrow DataFusion' -copyright = '2023, Apache Software Foundation' -author = 'Arrow DataFusion Authors' +project = 'Apache Arrow DataFusion' +copyright = '2019-2024, Apache Software Foundation' +author = 'Apache Software Foundation' # -- General configuration --------------------------------------------------- diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 8d69ade83d72..cb0fe63abd91 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -95,7 +95,7 @@ Compiling DataFusion from sources requires an installed version of the protobuf On most platforms this can be installed from your system's package manager ``` -$ apt install -y protobuf-compiler +$ sudo apt install -y protobuf-compiler $ dnf install -y protobuf-devel $ pacman -S protobuf $ brew install protobuf diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 1f687f978f30..64dc25411deb 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -398,7 +398,8 @@ impl Accumulator for GeometricMean { ### registering an Aggregate UDF -To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier. +To register an Aggregate UDF, you need to wrap the function implementation in an [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper function to make this easier. +There is also a lower-level API with more functionality, but it is more complex; it is documented in [`advanced_udaf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udaf}; @@ -421,6 +422,10 @@ let geometric_mean = create_udaf( ); ``` +[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html +[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html +[`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs + The `create_udaf` has six arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries. 
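The adding-udfs.md hunk above stops at constructing `geometric_mean` with `create_udaf`. For context, here is a short hedged sketch of how that UDAF might then be registered and invoked; the table name, column name, and the `geo_mean` SQL name are assumptions for illustration (the UDAF value is whatever `create_udaf` returned in the guide's snippet).

```rust
use datafusion::error::Result;
use datafusion::logical_expr::AggregateUDF;
use datafusion::prelude::SessionContext;

/// Hypothetical driver: `geometric_mean` is the `AggregateUDF` built with
/// `create_udaf` in the guide, assumed to have been given the SQL name
/// `geo_mean`. The table and column names below are made up for the example.
async fn run_geo_mean(geometric_mean: AggregateUDF) -> Result<()> {
    let ctx = SessionContext::new();
    // Make the UDAF resolvable by name in SQL planned through this context.
    ctx.register_udaf(geometric_mean);

    // A tiny in-memory table to aggregate over.
    ctx.sql("CREATE TABLE observations(a DOUBLE) AS VALUES (2.0), (8.0), (64.0)")
        .await?;

    // geo_mean(2, 8, 64) = (2 * 8 * 64)^(1/3) which is roughly 10.08
    ctx.sql("SELECT geo_mean(a) AS g FROM observations")
        .await?
        .show()
        .await?;
    Ok(())
}
```

Once registered, the UDAF is resolved by name during SQL planning just like a built-in aggregate.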
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 0a5c221c5034..7111ea1d0ab5 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -53,7 +53,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | | datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | | datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded | +| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | | datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | | datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | | datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | @@ -63,7 +63,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | +| datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. 
| | datafusion.execution.parquet.created_by | datafusion version 34.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index b8689e556741..85322d9fa766 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -237,6 +237,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | | array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 4, 5, 6]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 629a5f6ecb88..9dd008f8fc44 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1970,7 +1970,7 @@ array_prepend(element, array) Returns the array without the first element. ``` -array_pop_first(array) +array_pop_front(array) ``` #### Arguments @@ -1981,9 +1981,9 @@ array_pop_first(array) #### Example ``` -❯ select array_pop_first([1, 2, 3]); +❯ select array_pop_front([1, 2, 3]); +-------------------------------+ -| array_pop_first(List([1,2,3])) | +| array_pop_front(List([1,2,3])) | +-------------------------------+ | [2, 3] | +-------------------------------+ diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 470591afafff..75aa0d77b95c 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -69,9 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq The following special options are specific to the `COPY` command. | Option | Description | Default Value | -| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. 
diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md
index 470591afafff..75aa0d77b95c 100644
--- a/docs/source/user-guide/sql/write_options.md
+++ b/docs/source/user-guide/sql/write_options.md
@@ -69,9 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq
 The following special options are specific to the `COPY` command.
 
 | Option | Description | Default Value |
-| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- |
+| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- |
 | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true |
-| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | |
+| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A |
@@ -100,21 +100,21 @@ The following options are available when writing CSV files. Note: if any unsuppo
 The following options are available when writing parquet files. If any unsupported option is specified an error will be raised and the query will fail. If a column specific option is specified for a column which does not exist, the option will be ignored without error. For default values, see: [Configuration Settings](https://arrow.apache.org/datafusion/user-guide/configs.html).
 
-| Option | Can be Column Specific? | Description |
-| ---------------------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------- |
-| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use |
-| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group |
-| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes |
-| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch |
-| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) |
-| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes |
-| CREATED_BY | No | Sets the "created by" property in the parquet file |
-| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. |
-| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. |
-| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. |
-| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) |
-| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. |
-| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. |
-| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. |
-| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. |
-| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. |
+| Option | Can be Column Specific? | Description |
+| ---------------------------- | ----------------------- | ----------------------------------------------------------------------------------------------------------------------------------- |
+| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use |
+| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group. Larger row groups require more memory to write and read. |
+| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes |
+| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch |
+| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) |
+| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes |
+| CREATED_BY | No | Sets the "created by" property in the parquet file |
+| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. |
+| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. |
+| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. |
+| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) |
+| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. |
+| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. |
+| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. |
+| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. |
+| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. |
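For readers of the table above, a hypothetical `COPY` statement combining several of these options might look like the sketch below. It assumes the parenthesized option syntax described in `write_options.md`; `source_table`, the output path, and the option values are placeholders, not values taken from this change:

```sql
-- Write source_table to a directory of Parquet files, overriding a few writer options
COPY source_table
TO 'output/parquet_out'
(FORMAT parquet,             -- required here because SINGLE_FILE_OUTPUT is false
 SINGLE_FILE_OUTPUT false,   -- write multiple files to the directory in parallel
 MAX_ROW_GROUP_SIZE 262144,  -- smaller row groups use less memory while writing
 COMPRESSION 'zstd(10)');    -- codec and level, per the COMPRESSION row above
```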
diff --git a/pre-commit.sh b/pre-commit.sh
index f82390e229a9..09cf431a1409 100755
--- a/pre-commit.sh
+++ b/pre-commit.sh
@@ -60,7 +60,10 @@ echo -e "$(GREEN INFO): cargo clippy ..."
 
 # Cargo clippy always return exit code 0, and `tee` doesn't work.
 # So let's just run cargo clippy.
-cargo clippy
+cargo clippy --all-targets --workspace --features avro,pyarrow -- -D warnings
+pushd datafusion-cli
+cargo clippy --all-targets --all-features -- -D warnings
+popd
 echo -e "$(GREEN INFO): cargo clippy done"
 
 # 2. cargo fmt: format with nightly and stable.