From a8a934aa764f20d8ac03f0f7d10bb49708833678 Mon Sep 17 00:00:00 2001 From: cetra3 Date: Wed, 29 Jan 2025 11:07:58 +1030 Subject: [PATCH] Refactor & add tests --- datafusion/sql/src/unparser/expr.rs | 68 ++++++++++++++--------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 67ba3b106dee8..4f344fa5ee1c0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -33,7 +33,7 @@ use arrow_array::types::{ Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; +use arrow_array::{ArrayRef, Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, @@ -528,6 +528,16 @@ impl Unparser<'_> { })) } + fn scalar_value_list_to_sql(&self, array: &ArrayRef) -> Result { + let mut elem = Vec::new(); + for i in 0..array.len() { + let value = ScalarValue::try_from_array(&array, i)?; + elem.push(self.scalar_to_sql(&value)?); + } + + Ok(ast::Expr::Array(Array { elem, named: false })) + } + fn array_element_to_sql(&self, args: &[Expr]) -> Result { if args.len() != 2 { return internal_err!("array_element must have exactly 2 arguments"); @@ -1120,39 +1130,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported scalar: {v:?}") } ScalarValue::LargeBinary(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::FixedSizeList(a) => { - let array = a.values(); - - let mut elem = Vec::new(); - for i in 0..array.len() { - let value = ScalarValue::try_from_array(&array, i)?; - elem.push(self.scalar_to_sql(&value)?); - } - - Ok(ast::Expr::Array(Array { elem, named: true })) - } - ScalarValue::List(a) => { - let array = a.values(); - - let mut elem = Vec::new(); - for i in 0..array.len() { - let value = ScalarValue::try_from_array(&array, i)?; - elem.push(self.scalar_to_sql(&value)?); - } - - Ok(ast::Expr::Array(Array { elem, named: true })) - } - ScalarValue::LargeList(a) => { - let array = a.values(); - - let mut elem = Vec::new(); - for i in 0..array.len() { - let value = ScalarValue::try_from_array(&array, i)?; - elem.push(self.scalar_to_sql(&value)?); - } - - Ok(ast::Expr::Array(Array { elem, named: true })) - } + ScalarValue::FixedSizeList(a) => self.scalar_value_list_to_sql(a.values()), + ScalarValue::List(a) => self.scalar_value_list_to_sql(a.values()), + ScalarValue::LargeList(a) => self.scalar_value_list_to_sql(a.values()), ScalarValue::Date32(Some(_)) => { let date = v .to_array()? @@ -1655,8 +1635,9 @@ mod tests { use std::ops::{Add, Sub}; use std::{any::Any, sync::Arc, vec}; - use arrow::datatypes::TimeUnit; use arrow::datatypes::{Field, Schema}; + use arrow::datatypes::{Int32Type, TimeUnit}; + use arrow_array::ListArray; use arrow_schema::DataType::Int8; use ast::ObjectName; use datafusion_common::TableReference; @@ -2091,6 +2072,23 @@ mod tests { map(vec![lit("a"), lit("b")], vec![lit(1), lit(2)]), "MAP {'a': 1, 'b': 2}", ), + ( + Expr::Literal(ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + )), + "'foo'", + ), + ( + Expr::Literal(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + ))), + "[1, 2, 3]", + ), ]; for (expr, expected) in tests {