Skip to content

Commit

Permalink
remove type coercion in the binary physical expr (#3396)
Browse files Browse the repository at this point in the history
* remove type coercion binary from phy

* fix test case

* revert the fix for #3387

* type coercion before simplify expression

* completely remove the type coercion in the physical plan

* refactor

* merge master

* refactor

* do type coercion in the simplify expression

* Add comments

* fix: fmt

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
liukun4515 and alamb authored Sep 24, 2022
1 parent b625277 commit d7c0e42
Show file tree
Hide file tree
Showing 13 changed files with 529 additions and 252 deletions.
4 changes: 4 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,11 @@ impl SessionState {
rules.push(Arc::new(FilterNullJoinKeys::default()));
}
rules.push(Arc::new(ReduceOuterJoin::new()));
// TODO: https://github.com/apache/arrow-datafusion/issues/3557
// remove this after the issue is fixed.
rules.push(Arc::new(TypeCoercion::new()));
// after type coercion, the simplify-expressions rule can be run again
rules.push(Arc::new(SimplifyExpressions::new()));
rules.push(Arc::new(FilterPushDown::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ mod tests {
use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;

use crate::error::Result;
Expand Down Expand Up @@ -525,7 +526,7 @@ mod tests {
expressions::binary(
expressions::col("a", &schema)?,
Operator::Gt,
expressions::lit(1u32),
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
source,
Expand Down Expand Up @@ -568,7 +569,7 @@ mod tests {
expressions::binary(
expressions::col("a", &schema)?,
Operator::Gt,
expressions::lit(1u32),
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
source,
Expand Down
22 changes: 15 additions & 7 deletions datafusion/core/src/physical_plan/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,14 @@ mod tests {
physical_plan::collect,
};
use arrow::array::Float32Array;
use arrow::datatypes::DataType::Decimal128;
use arrow::record_batch::RecordBatch;
use arrow::{
array::{Int64Array, Int8Array, StringArray},
datatypes::{DataType, Field},
};
use chrono::{TimeZone, Utc};
use datafusion_expr::{col, lit};
use datafusion_expr::{cast, col, lit};
use futures::StreamExt;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
Expand Down Expand Up @@ -1768,6 +1769,7 @@ mod tests {
// In this case, construct four types of statistics to be filtered with the decimal predicate.

// INT32: c1 > 5, the c1 is decimal(9,2)
// The type of the scalar value is decimal(9,2), so no cast is needed
let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2)));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]);
Expand Down Expand Up @@ -1809,11 +1811,15 @@ mod tests {
);

// INT32: c1 > 5, but parquet decimal type has different precision or scale to arrow decimal
// The c1 type is decimal(9,0) in the parquet file, and the type of the scalar is decimal(5,2).
// We should convert all types to the coerced type, which is decimal(11,2)
// The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0)
let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 5, 2)));
let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast(
lit(ScalarValue::Decimal128(Some(500), 5, 2)),
Decimal128(11, 2),
));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(5, 2), false)]);
// The decimal of parquet is decimal(9,0)
Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]);
let schema_descr = get_test_schema_descr(vec![(
"c1",
PhysicalType::INT32,
Expand Down Expand Up @@ -1901,11 +1907,13 @@ mod tests {
vec![1]
);

// FIXED_LENGTH_BYTE_ARRAY: c1 = 100, the c1 is decimal(28,2)
// FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2)
// the type of parquet is decimal(18,2)
let expr = col("c1").eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 3), false)]);
Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]);
// cast the type of c1 to decimal(28,3)
let left = cast(col("c1"), DataType::Decimal128(28, 3));
let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let schema_descr = get_test_schema_descr(vec![(
"c1",
PhysicalType::FIXED_LEN_BYTE_ARRAY,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,7 @@ mod tests {
use crate::execution::runtime_env::RuntimeEnv;
use crate::logical_plan::plan::Extension;
use crate::physical_plan::{
expressions, DisplayFormatType, Partitioning, Statistics,
expressions, DisplayFormatType, Partitioning, PhysicalPlanner, Statistics,
};
use crate::prelude::{SessionConfig, SessionContext};
use crate::scalar::ScalarValue;
Expand Down Expand Up @@ -1736,10 +1736,10 @@ mod tests {
let exec_plan = plan(&logical_plan).await?;

// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for the literal will be simplified away, yielding lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
assert!(format!("{:?}", exec_plan).contains(expected));

Ok(())
}

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
assert_eq!(results.len(), 1);

let expected = vec![
"+--------------+---------------------------+---------------------------+---------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"+--------------+-------------------------+-------------------------+-------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+-------------------------+-------------------------+-------------------------+",
];
assert_batches_sorted_eq!(expected, &results);

Expand Down
114 changes: 57 additions & 57 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------+",
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+----------------------------------------------------+",
"+------------------------------+",
"| decimal_simple.c1 + Int64(1) |",
"+------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+------------------------------+",
];
assert_batches_eq!(expected, &actual);
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
Expand Down Expand Up @@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------+",
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+----------------------------------------------------+",
"+------------------------------+",
"| decimal_simple.c1 - Int64(1) |",
"+------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-----------------------------------------------------+",
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
"+-----------------------------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-----------------------------------------------------+",
"+-------------------------------+",
"| decimal_simple.c1 * Int64(20) |",
"+-------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ async fn csv_in_set_test() -> Result<()> {

#[tokio::test]
async fn multiple_or_predicates() -> Result<()> {
// TODO https://github.com/apache/arrow-datafusion/issues/3587
let ctx = SessionContext::new();
register_tpch_csv(&ctx, "lineitem").await?;
register_tpch_csv(&ctx, "part").await?;
Expand Down Expand Up @@ -424,15 +425,13 @@ async fn multiple_or_predicates() -> Result<()> {
let plan = state.optimize(&plan)?;
// Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been
// factored out and appear only once in the following plan
let expected =vec![
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,12 @@ async fn use_between_expression_in_select_query() -> Result<()> {
.unwrap()
.to_string();

// TODO https://github.com/apache/arrow-datafusion/issues/3587
// Only test that the projection exprs are correct, rather than entire output
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]";
assert_contains!(&formatted, needle);
let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
assert_contains!(&formatted, needle);

Ok(())
}

Expand Down
Loading

0 comments on commit d7c0e42

Please sign in to comment.