diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 396f09380af7..a82d79179773 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -131,6 +131,7 @@ use std::sync::Arc;
 
 use arrow_array::cast::*;
+use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::ArrowNativeType;
 use arrow_data::ArrayDataBuilder;
 use arrow_schema::*;
@@ -1445,6 +1446,102 @@ unsafe fn decode_column(
     Ok(array)
 }
 
+/// Downcasts `$array` to a `DictionaryArray<$key>` reference.
+///
+/// Panics if the array is not a dictionary with that key type, so `$key` must
+/// be chosen from the key type in the array's `DataType::Dictionary`.
+macro_rules! downcast_dict {
+    ($array:ident, $key:ident) => {{
+        $array
+            .as_any()
+            .downcast_ref::<DictionaryArray<$key>>()
+            .expect("dictionary key type must match the array's data type")
+    }};
+}
+
+/// Number of dictionary values at or above which `convert_columns` stops
+/// preserving the dictionary encoding of a column.
+const LOW_CARDINALITY_THRESHOLD: usize = 10;
+
+/// A [`RowConverter`] wrapper that picks the dictionary encoding per column
+/// from the first batch it converts.
+///
+/// Dictionary columns whose values array holds at least
+/// `LOW_CARDINALITY_THRESHOLD` entries are re-configured with
+/// `preserve_dictionaries = false`; all other columns are left untouched.
+#[derive(Debug)]
+pub struct CardinalityAwareRowConverter {
+    /// Converter that performs the actual row encoding and decoding.
+    inner: RowConverter,
+    /// `true` once the first batch has been inspected.
+    done: bool,
+}
+
+impl CardinalityAwareRowConverter {
+    /// Creates a converter for the given sort `fields`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error from [`RowConverter::new`] if it rejects `fields`.
+    pub fn new(fields: Vec<SortField>) -> Result<Self, ArrowError> {
+        Ok(Self {
+            inner: RowConverter::new(fields)?,
+            done: false,
+        })
+    }
+
+    /// Returns the size reported by the wrapped [`RowConverter`].
+    pub fn size(&self) -> usize {
+        self.inner.size()
+    }
+
+    /// Converts `rows` back into arrays via the wrapped [`RowConverter`].
+    pub fn convert_rows(&self, rows: &Rows) -> Result<Vec<ArrayRef>, ArrowError> {
+        self.inner.convert_rows(rows)
+    }
+
+    /// Converts `columns` into [`Rows`].
+    ///
+    /// The first call inspects every dictionary column and disables dictionary
+    /// preservation for those with at least `LOW_CARDINALITY_THRESHOLD` values.
+    /// Later calls reuse that decision and go straight to the wrapped converter.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if a replacement codec cannot be created or if the
+    /// wrapped converter fails to convert `columns`.
+    pub fn convert_columns(&mut self, columns: &[ArrayRef]) -> Result<Rows, ArrowError> {
+        if !self.done {
+            for (i, col) in columns.iter().enumerate() {
+                if let DataType::Dictionary(k, _) = col.data_type() {
+                    // The downcast target must match the key type exactly.
+                    let cardinality = match k.as_ref() {
+                        DataType::Int8 => downcast_dict!(col, Int8Type).values().len(),
+                        DataType::Int16 => downcast_dict!(col, Int16Type).values().len(),
+                        DataType::Int32 => downcast_dict!(col, Int32Type).values().len(),
+                        DataType::Int64 => downcast_dict!(col, Int64Type).values().len(),
+                        DataType::UInt8 => downcast_dict!(col, UInt8Type).values().len(),
+                        DataType::UInt16 => downcast_dict!(col, UInt16Type).values().len(),
+                        DataType::UInt32 => downcast_dict!(col, UInt32Type).values().len(),
+                        DataType::UInt64 => downcast_dict!(col, UInt64Type).values().len(),
+                        _ => unreachable!("dictionary keys are always integers"),
+                    };
+
+                    if cardinality >= LOW_CARDINALITY_THRESHOLD {
+                        // NOTE(review): only the codec is swapped; `inner.fields[i]`
+                        // keeps its original setting. Confirm decoding agrees.
+                        let mut sort_field = self.inner.fields[i].clone();
+                        sort_field.preserve_dictionaries = false;
+                        self.inner.codecs[i] = Codec::new(&sort_field)?;
+                    }
+                }
+            }
+            self.done = true;
+        }
+        self.inner.convert_columns(columns)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -1464,6 +1561,44 @@ mod tests {
     use super::*;
 
     #[test]
+    fn test_card_aware_row_converter() {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int32Array::from(vec![0, 0, 1, 2, 2, 1, 1, 0, 2]);
+        let a: ArrayRef =
+            Arc::new(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap());
+        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
+            Some("a"),
+            Some("c"),
+            Some("e"),
+            Some("g"),
+            Some("i"),
+            Some("k"),
+            Some("m"),
+            Some("o"),
+            Some("q"),
+        ]));
+        let cols = [a, b];
+        let mut converter = CardinalityAwareRowConverter::new(vec![
+            SortField::new(DataType::Dictionary(
+                Box::new(DataType::Int32),
+                Box::new(DataType::Utf8),
+            )),
+            SortField::new(DataType::Utf8),
+        ])
+        .unwrap();
+
+        let rows = converter.convert_columns(&cols).unwrap();
+        assert_eq!(rows.num_rows(), 9);
+
+        // Three dictionary values is below the threshold, so the dictionary
+        // encoding is preserved and both columns round-trip unchanged.
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), cols.len());
+        assert_eq!(&back[0], &cols[0]);
+        assert_eq!(&back[1], &cols[1]);
+    }
+
+    #[test]
     fn test_fixed_width() {
         let cols = [
             Arc::new(Int16Array::from_iter([