From 1c01208d16777fa484863eb182e5a309899d9d9e Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 16:07:49 -0700 Subject: [PATCH 1/7] Initial version of CardinalityAwareRowConverter --- arrow-row/src/lib.rs | 71 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 396f09380af7..3c3f9affa225 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -131,6 +131,7 @@ use std::sync::Arc; use arrow_array::cast::*; use arrow_array::*; +use arrow_array::types::{ArrowDictionaryKeyType, ArrowPrimitiveType}; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; use arrow_schema::*; @@ -1445,6 +1446,48 @@ unsafe fn decode_column( Ok(array) } +#[derive(Debug)] +pub struct CardinalityAwareRowConverter { + inner: RowConverter, + done: bool, +} + +impl CardinalityAwareRowConverter { + pub fn new(fields: Vec) -> Result { + Ok(Self { + inner: RowConverter::new(fields)?, + done: false, + }) + } + + pub fn convert_rows(&self, rows: &Rows) -> Result, ArrowError> { + self.inner.convert_rows(rows) + } + + pub fn convert_columns( + &mut self, + columns: &[ArrayRef]) -> Result + where + K: ArrowDictionaryKeyType + { + if !self.done { + for (i, col) in columns.iter().enumerate() { + if let DataType::Dictionary(_, _) = col.data_type() { + let cardinality = col.as_dictionary::().values().len(); + println!("cardinality: {}", cardinality); + if cardinality >= 1 { + let mut sort_field = self.inner.fields[i].clone(); + sort_field.preserve_dictionaries = false; + self.inner.codecs[i] = Codec::new(&sort_field).unwrap(); + } + } + } + } + self.done = true; + self.inner.convert_columns(columns) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -1464,6 +1507,34 @@ mod tests { use super::*; #[test] + fn test_cardinality() { + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int32Array::from(vec![0, 0, 1, 2, 2, 1, 1, 0, 2]); + let a: ArrayRef = Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("c"), + Some("e"), + Some("g"), + Some("i"), + Some("k"), + Some("m"), + Some("o"), + Some("q"), + ])); + let cols = [a, b]; + let mut converter = CardinalityAwareRowConverter::new(vec![ + SortField::new(DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))), + SortField::new(DataType::Utf8), + ]) + .unwrap(); + let rows = converter.convert_columns::(&cols).unwrap(); + let back = converter.convert_rows(&rows).unwrap(); + println!("{:?}", back); + + + } + fn test_fixed_width() { let cols = [ Arc::new(Int16Array::from_iter([ From af611400dc0b7756ca8dd4199da4a1210a0d8667 Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 16:12:34 -0700 Subject: [PATCH 2/7] Assume keys are always int32 type --- arrow-row/src/lib.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 3c3f9affa225..6137a2c01bce 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -131,7 +131,7 @@ use std::sync::Arc; use arrow_array::cast::*; use arrow_array::*; -use arrow_array::types::{ArrowDictionaryKeyType, ArrowPrimitiveType}; +use arrow_array::types::{ArrowDictionaryKeyType, ArrowPrimitiveType, Int32Type}; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; use arrow_schema::*; @@ -1464,16 +1464,13 @@ impl CardinalityAwareRowConverter { self.inner.convert_rows(rows) } - pub fn convert_columns( + pub fn convert_columns( &mut self, - columns: &[ArrayRef]) -> Result - where - K: ArrowDictionaryKeyType - { + columns: &[ArrayRef]) -> Result { if !self.done { for (i, col) in columns.iter().enumerate() { if let DataType::Dictionary(_, _) = col.data_type() { - let cardinality = col.as_dictionary::().values().len(); + let cardinality = col.as_any().downcast_ref::>().unwrap().keys().len(); println!("cardinality: {}", cardinality); if cardinality >= 1 { let mut sort_field = self.inner.fields[i].clone(); @@ -1528,7 +1525,7 @@ mod tests { SortField::new(DataType::Utf8), ]) .unwrap(); - let rows = converter.convert_columns::(&cols).unwrap(); + let rows = converter.convert_columns(&cols).unwrap(); let back = converter.convert_rows(&rows).unwrap(); println!("{:?}", back); From e70d4e6bbd3668060268cbca5b7a4cf84cfe6827 Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 16:17:50 -0700 Subject: [PATCH 3/7] Assume keys are always int32 type --- arrow-row/src/lib.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 6137a2c01bce..b888d0252a70 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1470,9 +1470,8 @@ impl CardinalityAwareRowConverter { if !self.done { for (i, col) in columns.iter().enumerate() { if let DataType::Dictionary(_, _) = col.data_type() { - let cardinality = col.as_any().downcast_ref::>().unwrap().keys().len(); - println!("cardinality: {}", cardinality); - if cardinality >= 1 { + let cardinality = col.as_any().downcast_ref::>().unwrap().values().len(); + if cardinality >= 10 { let mut sort_field = self.inner.fields[i].clone(); sort_field.preserve_dictionaries = false; self.inner.codecs[i] = Codec::new(&sort_field).unwrap(); @@ -1504,7 +1503,7 @@ mod tests { use super::*; #[test] - fn test_cardinality() { + fn test_card_aware_row_converter() { let values = StringArray::from_iter_values(["a", "b", "c"]); let keys = Int32Array::from(vec![0, 0, 1, 2, 2, 1, 1, 0, 2]); let a: ArrayRef = Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()); @@ -1528,8 +1527,6 @@ mod tests { let rows = converter.convert_columns(&cols).unwrap(); let back = converter.convert_rows(&rows).unwrap(); println!("{:?}", back); - - } fn test_fixed_width() { From 37f4578650ee5a406cb05d1e47d6ab90c602dc7d Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 16:29:44 -0700 Subject: [PATCH 4/7] Remove unnecessary imports --- arrow-row/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index b888d0252a70..97b01d55ede4 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -131,7 +131,7 @@ use std::sync::Arc; use arrow_array::cast::*; use arrow_array::*; -use arrow_array::types::{ArrowDictionaryKeyType, ArrowPrimitiveType, Int32Type}; +use arrow_array::types::Int32Type; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; use arrow_schema::*; From c3cdfd355824378e460ee5704f0d7975c708980b Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 16:45:03 -0700 Subject: [PATCH 5/7] Add a size method --- arrow-row/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 97b01d55ede4..dd9cb4b0df3a 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1460,6 +1460,10 @@ impl CardinalityAwareRowConverter { }) } + pub fn size(&self) -> usize { + self.inner.size() + } + pub fn convert_rows(&self, rows: &Rows) -> Result, ArrowError> { self.inner.convert_rows(rows) } From 70ce44ee0b04e7bbf536ea10e3defc59f04e7fb9 Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 17:44:50 -0700 Subject: [PATCH 6/7] Support other integer key types --- arrow-row/src/lib.rs | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index dd9cb4b0df3a..28fcfc1227c9 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -131,7 +131,7 @@ use std::sync::Arc; use arrow_array::cast::*; use arrow_array::*; -use arrow_array::types::Int32Type; +use arrow_array::types::*; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; use arrow_schema::*; @@ -1446,6 +1446,15 @@ unsafe fn decode_column( Ok(array) } +macro_rules! downcast_dict { + ($array:ident, $key:ident) => {{ + $array + .as_any() + .downcast_ref::>() + .unwrap() + }}; +} + #[derive(Debug)] pub struct CardinalityAwareRowConverter { inner: RowConverter, @@ -1473,8 +1482,19 @@ impl CardinalityAwareRowConverter { columns: &[ArrayRef]) -> Result { if !self.done { for (i, col) in columns.iter().enumerate() { - if let DataType::Dictionary(_, _) = col.data_type() { - let cardinality = col.as_any().downcast_ref::>().unwrap().values().len(); + if let DataType::Dictionary(k, _) = col.data_type() { + // let cardinality = col.as_any().downcast_ref::>().unwrap().values().len(); + let cardinality = match k.as_ref() { + DataType::Int8 => downcast_dict!(col, Int32Type).values().len(), + DataType::Int16 => downcast_dict!(col, Int32Type).values().len(), + DataType::Int32 => downcast_dict!(col, Int32Type).values().len(), + DataType::Int64 => downcast_dict!(col, Int64Type).values().len(), + DataType::UInt16 => downcast_dict!(col, UInt16Type).values().len(), + DataType::UInt32 => downcast_dict!(col, UInt32Type).values().len(), + DataType::UInt64 => downcast_dict!(col, UInt64Type).values().len(), + _ => unreachable!(), + }; + if cardinality >= 10 { let mut sort_field = self.inner.fields[i].clone(); sort_field.preserve_dictionaries = false; From b436084a9a5b2a8fa92f9a7929320ecd0d3ecbc1 Mon Sep 17 00:00:00 2001 From: Jayjeet Chakraborty Date: Thu, 24 Aug 2023 17:52:20 -0700 Subject: [PATCH 7/7] Define the low cardinality threshold using a const --- arrow-row/src/lib.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 28fcfc1227c9..a82d79179773 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1455,6 +1455,8 @@ macro_rules! downcast_dict { }}; } +const LOW_CARDINALITY_THRESHOLD: usize = 10; + #[derive(Debug)] pub struct CardinalityAwareRowConverter { inner: RowConverter, @@ -1495,7 +1497,7 @@ impl CardinalityAwareRowConverter { _ => unreachable!(), }; - if cardinality >= 10 { + if cardinality >= LOW_CARDINALITY_THRESHOLD { let mut sort_field = self.inner.fields[i].clone(); sort_field.preserve_dictionaries = false; self.inner.codecs[i] = Codec::new(&sort_field).unwrap();