Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cubesql): Implement format and col_description #9072

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 177 additions & 11 deletions rust/cubesql/cubesql/src/compile/engine/udf/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3087,10 +3087,18 @@
Some(as_str) => {
match PgType::get_all().iter().find(|e| e.typname == as_str) {
None => {
return Err(DataFusionError::Execution(format!(
"Unable to cast expression to Regclass: Unknown type: {}",
as_str
)))
// If the type name contains a dot, it's a schema-qualified name
// and we should return the approprate RegClass to be converted to OID
// For now, we'll return 0 so metabase can sync without failing
// TODO actually read `pg_type`
if as_str.contains('.') {
builder.append_value(0)?;
} else {
return Err(DataFusionError::Execution(format!(
"Unable to cast expression to Regclass: Unknown type: {}",
as_str
)));

Check warning on line 3100 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3097-L3100

Added lines #L3097 - L3100 were not covered by tests
}
}
Some(ty) => {
builder.append_value(ty.oid as i64)?;
Expand Down Expand Up @@ -3148,6 +3156,171 @@
)
}

// Return a NOOP for this so metabase can sync without failing
// See https://www.postgresql.org/docs/17/functions-info.html#FUNCTIONS-INFO-COMMENT here
// TODO: Implement this
pub fn create_col_description_udf() -> ScalarUDF {
let fun = make_scalar_function(move |args: &[ArrayRef]| {
// Ensure the output array has the same length as the input
let input_length = args[0].len();
let mut builder = StringBuilder::new(input_length);

for _ in 0..input_length {
builder.append_null()?;
}

Ok(Arc::new(builder.finish()) as ArrayRef)
});

let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));

ScalarUDF::new(
"col_description",
// Correct signature for col_description should be `(oid, integer) → text`
// We model oid as UInt32, so [DataType::UInt32, DataType::Int32] is a proper arguments
// However, it seems that coercion rules in DF differs from PostgreSQL at the moment
// And metabase uses col_description(CAST(CAST(... AS regclass) AS oid), cardinal_number)
// And we model regclass as Int64, and cardinal_number as UInt32
// Which is why second signature is necessary
&Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::UInt32, DataType::Int32]),
// TODO remove this signature in favor of proper model/coercion
TypeSignature::Exact(vec![DataType::Int64, DataType::UInt32]),
],
Volatility::Stable,
),
&return_type,
&fun,
)
}

// See https://www.postgresql.org/docs/17/functions-string.html#FUNCTIONS-STRING-FORMAT
pub fn create_format_udf() -> ScalarUDF {
let fun = make_scalar_function(move |args: &[ArrayRef]| {
// Ensure at least one argument is provided
if args.is_empty() {
return Err(DataFusionError::Execution(
"format() requires at least one argument".to_string(),
));

Check warning on line 3205 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3203-L3205

Added lines #L3203 - L3205 were not covered by tests
}

// Ensure the first argument is a Utf8 (string)
if args[0].data_type() != &DataType::Utf8 {
return Err(DataFusionError::Execution(
"format() first argument must be a string".to_string(),
));

Check warning on line 3212 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3210-L3212

Added lines #L3210 - L3212 were not covered by tests
}

let format_strings = downcast_string_arg!(&args[0], "format_str", i32);
let mut builder = StringBuilder::new(format_strings.len());

for i in 0..format_strings.len() {
if format_strings.is_null(i) {
builder.append_null()?;
continue;

Check warning on line 3221 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3220-L3221

Added lines #L3220 - L3221 were not covered by tests
}

let format_str = format_strings.value(i);
let mut result = String::new();
let mut format_chars = format_str.chars().peekable();
let mut arg_index = 1; // Start from first argument after format string

while let Some(c) = format_chars.next() {
if c != '%' {
result.push(c);
continue;
}

match format_chars.next() {
Some('I') => {
// Handle %I - SQL identifier
if arg_index >= args.len() {
return Err(DataFusionError::Execution(
"Not enough arguments for format string".to_string(),
));

Check warning on line 3241 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3239-L3241

Added lines #L3239 - L3241 were not covered by tests
}

let arg = &args[arg_index];
let value = match arg.data_type() {
DataType::Utf8 => {
let str_arr = downcast_string_arg!(arg, "arg", i32);
if str_arr.is_null(i) {
return Err(DataFusionError::Execution(
"NULL values cannot be formatted as identifiers"
.to_string(),
));

Check warning on line 3252 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3249-L3252

Added lines #L3249 - L3252 were not covered by tests
}
str_arr.value(i).to_string()
}
_ => {
// For other types, try to convert to string
let str_arr = cast(&arg, &DataType::Utf8)?;
let str_arr =
str_arr.as_any().downcast_ref::<StringArray>().unwrap();
if str_arr.is_null(i) {
return Err(DataFusionError::Execution(
"NULL values cannot be formatted as identifiers"
.to_string(),
));
}
str_arr.value(i).to_string()

Check warning on line 3267 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3258-L3267

Added lines #L3258 - L3267 were not covered by tests
}
};

// Quote any identifier for now
// That's a safety-first approach: it would quote too much, but every edge-case would be covered
// Like `1` or `1a` or `select`
// TODO Quote identifier only if necessary
let needs_quoting = true;

if needs_quoting {
result.push('"');
result.push_str(&value.replace('"', "\"\""));
result.push('"');
} else {
result.push_str(&value);
}

Check warning on line 3283 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3282-L3283

Added lines #L3282 - L3283 were not covered by tests
arg_index += 1;
}
Some('%') => {
// %% is escaped to single %
result.push('%');
}

Check warning on line 3289 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L3286-L3289

Added lines #L3286 - L3289 were not covered by tests
Some(c) => {
return Err(DataFusionError::Execution(format!(
"Unsupported format specifier %{}",
c
)));
}
None => {
return Err(DataFusionError::Execution(
"Invalid format string - ends with %".to_string(),
));
}
}
}

builder.append_value(result)?;
}

Ok(Arc::new(builder.finish()) as ArrayRef)
});

let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));

ScalarUDF::new(
"format",
// Actually, format should be variadic with types (Utf8, any*)
// But ATM DataFusion does not support those signatures
// And this would work through implicit casting to Utf8
// TODO migrate to proper custom signature once it's supported by DF
&Signature::variadic(vec![DataType::Utf8], Volatility::Immutable),
&return_type,
&fun,
)
}

pub fn create_json_build_object_udf() -> ScalarUDF {
let fun = make_scalar_function(move |_args: &[ArrayRef]| {
// TODO: Implement
Expand Down Expand Up @@ -3769,13 +3942,6 @@
rettyp = TimestampTz,
vol = Volatile
);
register_fun_stub!(
udf,
"col_description",
tsig = [Oid, Int32],
rettyp = Utf8,
vol = Stable
);
register_fun_stub!(udf, "convert", tsig = [Binary, Utf8, Utf8], rettyp = Binary);
register_fun_stub!(udf, "convert_from", tsig = [Binary, Utf8], rettyp = Utf8);
register_fun_stub!(udf, "convert_to", tsig = [Utf8, Utf8], rettyp = Binary);
Expand Down
54 changes: 54 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16211,4 +16211,58 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),

Ok(())
}

#[tokio::test]
async fn test_format_function() -> Result<(), CubeError> {
// Test: Basic usage with a single identifier
let result = execute_query(
"SELECT format('%I', 'column_name') AS formatted_identifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("formatted_identifier", result);

// Test: Using multiple identifiers
let result = execute_query(
"SELECT format('%I, %I', 'table_name', 'column_name') AS formatted_identifiers"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("formatted_identifiers", result);

// Test: Unsupported format specifier
let result = execute_query(
"SELECT format('%X', 'value') AS unsupported_specifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;
assert!(result.is_err());

// Test: Format string ending with %
let result = execute_query(
"SELECT format('%', 'value') AS invalid_format".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;
assert!(result.is_err());

// Test: Quoting necessary for special characters
let result = execute_query(
"SELECT format('%I', 'column-name') AS quoted_identifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("quoted_identifier", result);

// Test: Quoting necessary for reserved keywords
let result = execute_query(
"SELECT format('%I', 'select') AS quoted_keyword".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("quoted_keyword", result);

Ok(())
}
}
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/compile/query_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ impl QueryEngine for SqlQueryEngine {
ctx.register_udf(create_current_timestamp_udf("localtimestamp"));
ctx.register_udf(create_current_schema_udf());
ctx.register_udf(create_current_schemas_udf());
ctx.register_udf(create_format_udf());
ctx.register_udf(create_format_type_udf());
ctx.register_udf(create_col_description_udf());
ctx.register_udf(create_pg_datetime_precision_udf());
ctx.register_udf(create_pg_numeric_precision_udf());
ctx.register_udf(create_pg_numeric_scale_udf());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: result
---
+----------------------+
| formatted_identifier |
+----------------------+
| "column_name" |
+----------------------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: result
---
+-----------------------------+
| formatted_identifiers |
+-----------------------------+
| "table_name", "column_name" |
+-----------------------------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: result
---
+-------------------+
| quoted_identifier |
+-------------------+
| "column-name" |
+-------------------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: result
---
+----------------+
| quoted_keyword |
+----------------+
| "select" |
+----------------+
Loading
Loading