diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8ef139ae6123..7bd3bf716508 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -22,9 +22,10 @@ use std::fmt::Debug; use std::mem::size_of_val; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use arrow::array::{ArrayRef, AsArray, BooleanArray, UInt64Array}; use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; +use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, @@ -542,6 +543,9 @@ impl LastValueAccumulator { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in LAST_VALUE"); }; + + let num_rows = value.len(); + if self.requirement_satisfied { // Get last entry according to the order of data: if self.ignore_nulls { @@ -556,7 +560,7 @@ impl LastValueAccumulator { return Ok((!value.is_empty()).then_some(value.len() - 1)); } } - let sort_columns = ordering_values + let mut sort_columns = ordering_values .iter() .zip(self.ordering_req.iter()) .map(|(values, req)| { @@ -569,6 +573,13 @@ impl LastValueAccumulator { }) .collect::<Vec<_>>(); + // Order by indices for cases where the values are the same, we expect the last index + let indices: UInt64Array = (0..num_rows).map(|x| x as u64).collect(); + sort_columns.push(SortColumn { + values: Arc::new(indices), + options: Some(!SortOptions::default()), + }); + if self.ignore_nulls { let indices = lexsort_to_indices(&sort_columns, None)?; // If ignoring nulls, find the last non-null value. @@ -607,13 +618,14 @@ impl Accumulator for LastValueAccumulator { } else if let Some(last_idx) = self.get_last_idx(values)?
{ let row = get_row_at_idx(values, last_idx)?; let orderings = &row[1..]; + // Update when there is a more recent entry if compare_rows( &self.orderings, orderings, &get_sort_options(self.ordering_req.as_ref()), )? - .is_lt() + .is_le() { self.update_with_new_row(&row); } @@ -652,7 +664,7 @@ impl Accumulator for LastValueAccumulator { // version in the new data: if !self.is_set || self.requirement_satisfied - || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() + || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_le() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -701,9 +713,98 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec Result<()> { + // TODO: Move this kind of test to slt, we don't have a nice way to define the batch size for each `update_batch` + // so there is no trivial way to test this in slt for now + + // test query: select last_value(a order by b) from t1, where b has same value + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ]); + + // TODO: Cleanup state in `evaluate` or introduce another method? 
+ // We don't have a cleanup method to reset the state, so create a new one each time + fn create_acc(schema: &Schema, asc: bool) -> Result<LastValueAccumulator> { + LastValueAccumulator::try_new( + &DataType::Int64, + &[DataType::Int64], + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", schema)?, + if asc { + SortOptions::default() + } else { + SortOptions::default().desc() + }, + )]), + false, + ) + } + + let mut last_accumulator = create_acc(&schema, true)?; + let values = vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 1, 1])) as ArrayRef, // b + ]; + last_accumulator.update_batch(&values)?; + let values = vec![ + Arc::new(Int64Array::from(vec![4, 5, 6])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 1, 1])) as ArrayRef, // b + ]; + last_accumulator.update_batch(&values)?; + assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(6))); + + let mut last_accumulator = create_acc(&schema, true)?; + let values = vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 2])) as ArrayRef, // b + ]; + last_accumulator.update_batch(&values)?; + assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(6))); + + let mut last_accumulator = create_acc(&schema, true)?; + let values = vec![ + Arc::new(Int64Array::from(vec![7, 8, 9])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![2, 2, 2])) as ArrayRef, // b + ]; + last_accumulator.update_batch(&values)?; + assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(9))); + + let mut last_accumulator = create_acc(&schema, true)?; + let states = vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 2, 2, 1, 1])) as ArrayRef, // order by + Arc::new(BooleanArray::from(vec![true; 5])) as ArrayRef, // is set + ]; + last_accumulator.merge_batch(&states)?; + last_accumulator.merge_batch(&states)?; +
assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(3))); + + // desc + let mut last_accumulator = create_acc(&schema, false)?; + let states = vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 2, 2, 1, 1])) as ArrayRef, // order by + Arc::new(BooleanArray::from(vec![true; 5])) as ArrayRef, // is set + ]; + last_accumulator.merge_batch(&states)?; + let states = vec![ + Arc::new(Int64Array::from(vec![7, 8, 9])) as ArrayRef, // a + Arc::new(Int64Array::from(vec![1, 1, 1])) as ArrayRef, // order by + Arc::new(BooleanArray::from(vec![true; 3])) as ArrayRef, // is set + ]; + last_accumulator.merge_batch(&states)?; + assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(9))); + Ok(()) + } + #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = FirstValueAccumulator::try_new( diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 056f88450c9f..0d2f581a1922 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -2998,12 +2998,22 @@ physical_plan 05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 06)----------MemoryExec: partitions=1, partition_sizes=[1] +query RP +select amount, ts from sales_global; +---- +30 2022-01-01T06:00:00 +50 2022-01-01T08:00:00 +75 2022-01-01T11:30:00 +200 2022-01-02T12:00:00 +100 2022-01-03T10:00:00 +80 2022-01-03T10:00:00 + query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts ASC) AS fv2 FROM sales_global ---- -30 100 +30 80 # Conversion in between FIRST_VALUE and LAST_VALUE to resolve # contradictory requirements should work in multi partitions.