Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 24, 2023
1 parent 133ead4 commit e5e3bcb
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 58 deletions.
32 changes: 20 additions & 12 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,36 @@ pub fn binary_operator_data_type(
let result_type = if !matches!(coerced_type, DataType::Decimal128(_, _)) {
coerced_type
} else {
let lhs_type = if matches!(lhs_type, DataType::Decimal128(_, _) | DataType::Null)
{
lhs_type.clone()
} else {
coerce_numeric_type_to_decimal(lhs_type).ok_or_else(|| {
let lhs_type = match lhs_type {
DataType::Decimal128(_, _) | DataType::Null => lhs_type.clone(),
DataType::Dictionary(_, value_type)
if matches!(**value_type, DataType::Decimal128(_, _)) =>
{
lhs_type.clone()
}
_ => coerce_numeric_type_to_decimal(lhs_type).ok_or_else(|| {
DataFusionError::Internal(format!(
"Could not coerce numeric type to decimal: {:?}",
lhs_type
))
})?
})?,
};
let rhs_type = if matches!(rhs_type, DataType::Decimal128(_, _) | DataType::Null)
{
rhs_type.clone()
} else {
coerce_numeric_type_to_decimal(rhs_type).ok_or_else(|| {

let rhs_type = match rhs_type {
DataType::Decimal128(_, _) | DataType::Null => rhs_type.clone(),
DataType::Dictionary(_, value_type)
if matches!(**value_type, DataType::Decimal128(_, _)) =>
{
rhs_type.clone()
}
_ => coerce_numeric_type_to_decimal(rhs_type).ok_or_else(|| {
DataFusionError::Internal(format!(
"Could not coerce numeric type to decimal: {:?}",
rhs_type
))
})?
})?,
};

match op {
// For Plus and Minus, the result type is the same as the input type which is already promoted
Operator::Plus | Operator::Minus => coerced_type,
Expand Down
1 change: 1 addition & 0 deletions datafusion/jit/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr {
left,
op,
right,
..
}) => {
let op = match op {
datafusion_expr::Operator::Eq => BinaryExpr::Eq,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2455,6 +2455,7 @@ mod tests {
left: Box::new(left),
op: Operator::RegexMatch,
right: Box::new(right),
data_type: None,
})
}

Expand All @@ -2463,6 +2464,7 @@ mod tests {
left: Box::new(left),
op: Operator::RegexNotMatch,
right: Box::new(right),
data_type: None,
})
}

Expand All @@ -2471,6 +2473,7 @@ mod tests {
left: Box::new(left),
op: Operator::RegexIMatch,
right: Box::new(right),
data_type: None,
})
}

Expand All @@ -2479,6 +2482,7 @@ mod tests {
left: Box::new(left),
op: Operator::RegexNotIMatch,
right: Box::new(right),
data_type: None,
})
}

Expand Down
40 changes: 31 additions & 9 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,8 +1312,9 @@ mod tests {
op: Operator,
r: Arc<dyn PhysicalExpr>,
input_schema: &Schema,
x: &DataType,
) -> Arc<dyn PhysicalExpr> {
binary(l, op, r, input_schema).unwrap()
binary_with_data_type(l, op, r, input_schema, Some(x.clone())).unwrap()
}

#[test]
Expand All @@ -1331,6 +1332,7 @@ mod tests {
Operator::Lt,
col("b", &schema)?,
&schema,
&DataType::Boolean,
);
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
Expand Down Expand Up @@ -1364,15 +1366,18 @@ mod tests {
Operator::Lt,
col("b", &schema)?,
&schema,
&DataType::Boolean,
),
Operator::Or,
binary_simple(
col("a", &schema)?,
Operator::Eq,
col("b", &schema)?,
&schema,
&DataType::Boolean,
),
&schema,
&DataType::Boolean,
);
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
Expand Down Expand Up @@ -2828,8 +2833,13 @@ mod tests {
op: Operator,
expected: PrimitiveArray<T>,
) -> Result<()> {
let arithmetic_op =
binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema);
let arithmetic_op = binary_simple(
col("a", &schema)?,
op,
col("b", &schema)?,
&schema,
expected.data_type(),
);
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());

Expand All @@ -2845,7 +2855,8 @@ mod tests {
expected: ArrayRef,
) -> Result<()> {
let lit = Arc::new(Literal::new(literal));
let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, &schema);
let arithmetic_op =
binary_simple(col("a", &schema)?, op, lit, &schema, expected.data_type());
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());

Expand All @@ -2866,7 +2877,8 @@ mod tests {

let left_expr = try_cast(col("a", schema)?, schema, result_type.clone())?;
let right_expr = try_cast(col("b", schema)?, schema, result_type)?;
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let arithmetic_op =
binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -2896,7 +2908,8 @@ mod tests {
try_cast(col("a", schema)?, schema, op_type)?
};

let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let arithmetic_op =
binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);
Expand Down Expand Up @@ -2925,7 +2938,8 @@ mod tests {
try_cast(col("a", schema)?, schema, op_type)?
};

let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let arithmetic_op =
binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);
Expand Down Expand Up @@ -3493,7 +3507,7 @@ mod tests {
let tree_depth: i32 = 100;
let expr = (0..tree_depth)
.map(|_| col("a", schema.as_ref()).unwrap())
.reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema))
.reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema, &DataType::Int32))
.unwrap();

let result = expr
Expand Down Expand Up @@ -4000,7 +4014,13 @@ mod tests {
schema.field(1).is_nullable(),
),
]);
let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
let arithmetic_op = binary_simple(
left_expr,
op,
right_expr,
&coerced_schema,
expected.data_type(),
);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -4662,6 +4682,7 @@ mod tests {
Operator::GtEq,
lit(ScalarValue::from(25)),
&schema,
&DataType::Boolean,
);

let context = AnalysisContext::from_statistics(&schema, &statistics);
Expand Down Expand Up @@ -4691,6 +4712,7 @@ mod tests {
Operator::GtEq,
a.clone(),
&schema,
&DataType::Boolean,
);

let context = AnalysisContext::from_statistics(&schema, &statistics);
Expand Down
Loading

0 comments on commit e5e3bcb

Please sign in to comment.