diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 9fd8c75eab236..7c8819d4b4c53 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -144,19 +144,25 @@ where let result = iter .zip(start_array.iter()) .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + .map(|((string, start), count)| { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + exec_err!( "negative substring length not allowed: substr(, {start}, {count})" ) - } else { - let skip = max(0, start - 1); - let count = max(0, count + (if start < 1 {start - 1} else {0})); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } else { + let skip_value = start.checked_sub(1); + if skip_value.is_none() { + return exec_err!("negative overflow when calculating skip value"); + } + let skip = max(0, skip_value.unwrap()); + let count = max(0, count + (if start < 1 { start - 1 } else { 0 })); + Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } } + _ => Ok(None), } - _ => Ok(None), }) .collect::>>()?; @@ -482,6 +488,29 @@ mod tests { Utf8, StringArray ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abc")), + ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("overflow")), + ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + exec_err!("negative overflow when calculating skip value"), + &str, + Utf8, + StringArray + ); Ok(()) }