From 4a22b1537e6e1f7f0ae0475f9df27fede392943c Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 26 Sep 2024 17:23:10 -0400 Subject: [PATCH] feat: finish implementing deserialiser for record batch --- rig-lancedb/src/utils/deserializer.rs | 1054 ++++++++++++++++++------- 1 file changed, 774 insertions(+), 280 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index bbe915b5..8eace4f4 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use arrow_array::{ + cast::AsArray, types::{ - BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + ArrowDictionaryKeyType, BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, - LargeBinaryType, LargeUtf8Type, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, Utf8Type, + DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, + IntervalYearMonthType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, }, - Array, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, - GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, + Array, ArrowPrimitiveType, OffsetSizeTrait, RecordBatch, RunArray, StructArray, UnionArray, }; use lancedb::arrow::arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; use rig::vector_store::VectorStoreError; @@ -25,179 +25,291 @@ fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -trait Test { +pub trait RecordBatchDeserializer { fn deserialize(&self) -> Result; } -impl Test for RecordBatch { +impl RecordBatchDeserializer for RecordBatch { fn deserialize(&self) -> Result { fn type_matcher(column: &Arc) -> Result, VectorStoreError> { match column.data_type() { DataType::Null => Ok(vec![serde_json::Value::Null]), - // f16 does not implement serde_json::Deserialize. Need to cast to f32. - DataType::Float16 => column - .to_primitive::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|float_16| serde_json::to_value(float_16.to_f32())) - .collect::, _>>() + DataType::Float32 => column + .to_primitive_value::() .map_err(serde_to_rig_error), - DataType::Float32 => column.to_primitive_value::(), - DataType::Float64 => column.to_primitive_value::(), - DataType::Int8 => column.to_primitive_value::(), - DataType::Int16 => column.to_primitive_value::(), - DataType::Int32 => column.to_primitive_value::(), - DataType::Int64 => column.to_primitive_value::(), - DataType::UInt8 => column.to_primitive_value::(), - DataType::UInt16 => column.to_primitive_value::(), - DataType::UInt32 => column.to_primitive_value::(), - DataType::UInt64 => column.to_primitive_value::(), - DataType::Date32 => column.to_primitive_value::(), - DataType::Date64 => column.to_primitive_value::(), - DataType::Decimal128(..) => column.to_primitive_value::(), - // i256 does not implement serde_json::Deserialize. Need to cast to i128. - DataType::Decimal256(..) => column - .to_primitive::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|dec_256| serde_json::to_value(dec_256.as_i128())) - .collect::, _>>() + DataType::Float64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Decimal128(..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Microsecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Millisecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Second, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Nanosecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Interval(IntervalUnit::YearMonth) => column + .to_primitive_value::() .map_err(serde_to_rig_error), - DataType::Time32(TimeUnit::Second) => { - column.to_primitive_value::() - } - DataType::Time32(TimeUnit::Millisecond) => { - column.to_primitive_value::() - } - DataType::Time64(TimeUnit::Microsecond) => { - column.to_primitive_value::() - } - DataType::Time64(TimeUnit::Nanosecond) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Microsecond, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Millisecond, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Second, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Nanosecond, ..) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Microsecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Millisecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Nanosecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Second) => { - column.to_primitive_value::() - } DataType::Interval(IntervalUnit::DayTime) => Ok(column .to_primitive::() - .map_err(arrow_to_rig_error)? .iter() - .map(|interval| { + .map(|IntervalDayTime { days, milliseconds }| { json!({ - "days": interval.days, - "milliseconds": interval.milliseconds, + "days": days, + "milliseconds": milliseconds, }) }) .collect()), DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column .to_primitive::() - .map_err(arrow_to_rig_error)? .iter() - .map(|interval| { - json!({ - "months": interval.months, - "days": interval.days, - "nanoseconds": interval.nanoseconds, - }) - }) + .map( + |IntervalMonthDayNano { + months, + days, + nanoseconds, + }| { + json!({ + "months": months, + "days": days, + "nanoseconds": nanoseconds, + }) + }, + ) .collect()), - DataType::Interval(IntervalUnit::YearMonth) => { - column.to_primitive_value::() + DataType::Utf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeUtf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::Binary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeBinary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::FixedSizeBinary(n) => (0..*n) + .map(|i| serde_json::to_value(column.as_fixed_size_binary().value(i as usize))) + .collect::, _>>() + .map_err(serde_to_rig_error), + DataType::Boolean => { + let bool_array = column.as_boolean(); + (0..bool_array.len()) + .map(|i| bool_array.value(i)) + .map(serde_json::to_value) + .collect::, _>>() + .map_err(serde_to_rig_error) } - DataType::Utf8 | DataType::Utf8View => column.to_str_value::(), - DataType::LargeUtf8 => column.to_str_value::(), - DataType::Binary => column.to_str_value::(), - DataType::LargeBinary => column.to_str_value::(), - DataType::FixedSizeBinary(n) => { - match column.as_any().downcast_ref::() { - Some(list_array) => (0..*n) - .map(|j| serde_json::to_value(list_array.value(j as usize))) - .collect::, _>>() - .map_err(serde_to_rig_error), - None => Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "Can't cast column {column:?} to fixed size list array" - )), - ))), - } + DataType::FixedSizeList(..) => { + column.to_fixed_lists().iter().map(type_matcher).map_ok() } - DataType::FixedSizeList(..) => column - .fixed_nested_lists() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::List(..) | DataType::ListView(..) => column - .nested_lists::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::LargeList(..) | DataType::LargeListView(..) => column - .nested_lists::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::Struct(..) => match column.as_any().downcast_ref::() { - Some(struct_array) => struct_array - .nested_lists() - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), + DataType::List(..) => column.to_list::().iter().map(type_matcher).map_ok(), + DataType::LargeList(..) => { + column.to_list::().iter().map(type_matcher).map_ok() + } + DataType::Struct(..) => { + let struct_array = column.as_struct(); + let struct_columns = struct_array + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(struct_columns + .build_struct(struct_array.num_rows(), struct_array.column_names())) + } + DataType::Map(..) => { + let map_columns = column + .as_map() + .entries() + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(map_columns.build_map()) + } + DataType::Dictionary(keys_type, ..) => { + let (keys, v) = match **keys_type { + DataType::Int8 => column.to_dict_values::()?, + DataType::Int16 => column.to_dict_values::()?, + DataType::Int32 => column.to_dict_values::()?, + DataType::Int64 => column.to_dict_values::()?, + DataType::UInt8 => column.to_dict_values::()?, + DataType::UInt16 => column.to_dict_values::()?, + DataType::UInt32 => column.to_dict_values::()?, + DataType::UInt64 => column.to_dict_values::()?, + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Dictionary keys type is not accepted: {keys_type:?}" + )), + ))) + } + }; + + let values = type_matcher(v)?; + + Ok(keys + .iter() + .zip(values) + .map(|(k, v)| { + let mut map = serde_json::Map::new(); + map.insert(k.to_string(), v); + map + }) + .map(Value::Object) + .collect()) + } + DataType::Union(..) => match column.as_any().downcast_ref::() { + Some(union_array) => (0..union_array.len()) + .map(|i| union_array.value(i).clone()) + .collect::>() + .iter() + .map(type_matcher) + .map_ok(), None => Err(VectorStoreError::DatastoreError(Box::new( ArrowError::CastError(format!( - "Can't cast array: {column:?} to struct array" + "Can't cast column {column:?} to union array" )), ))), }, - // DataType::Map(..) => { - // let item = match column.as_any().downcast_ref::() { - // Some(map_array) => map_array - // .entries() - // .nested_lists() - // .iter() - // .map(|nested_list| type_matcher(nested_list, nested_list.data_type())) - // .collect::, _>>(), - // None => Err(VectorStoreError::DatastoreError(Box::new( - // ArrowError::CastError(format!( - // "Can't cast array: {column:?} to map array" - // )), - // ))), - // }?; - // } - // DataType::Dictionary(key_data_type, value_data_type) => { - // let item = match column.as_any().downcast_ref::() { - // Some(map_array) => { - // let keys = &Arc::new(map_array.keys()); - // type_matcher(keys, keys.data_type()) - // } - // None => Err(ArrowError::CastError(format!( - // "Can't cast array: {column:?} to map array" - // ))), - // }?; - // }, + DataType::RunEndEncoded(counter_type, ..) => { + let items: Vec> = match counter_type.data_type() { + DataType::Int16 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) + .map(|(n, value)| vec![value; n as usize]) + .collect() + } + DataType::Int32 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) + .map(|(n, value)| vec![value; n as usize]) + .collect() + } + DataType::Int64 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) + .map(|(n, value)| vec![value; n as usize]) + .collect() + } + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "RunEndEncoded index type is not accepted: {counter_type:?}" + )), + ))) + } + }; + + items + .iter() + .map(|item| serde_json::to_value(item).map_err(serde_to_rig_error)) + .collect() + } + // Not yet fully supported + DataType::BinaryView => { + todo!() + } + // Not yet fully supported + DataType::Utf8View => { + todo!() + } + // Not yet fully supported + DataType::ListView(..) => { + todo!() + } + // Not yet fully supported + DataType::LargeListView(..) => { + todo!() + } + // f16 currently unstable + DataType::Float16 => { + todo!() + } + // i256 currently unstable + DataType::Decimal256(..) => { + todo!() + } _ => { println!("Unsupported data type"); Ok(vec![serde_json::Value::Null]) @@ -213,128 +325,194 @@ impl Test for RecordBatch { println!("{:?}", serde_json::to_string(&columns).unwrap()); - Ok(json!({})) + serde_json::to_value(&columns).map_err(serde_to_rig_error) } } /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. pub trait DeserializePrimitiveArray { - fn to_primitive( - &self, - ) -> Result::Native>, ArrowError>; + /// Downcast arrow Array into a `PrimitiveArray` with items that implement trait `ArrowPrimitiveType`. + /// Return the primitive array values. + fn to_primitive(&self) -> Vec<::Native>; - fn to_primitive_value(&self) -> Result, VectorStoreError> + /// Same as above but convert the resulting array values into serde_json::Value. + fn to_primitive_value(&self) -> Result, serde_json::Error> where ::Native: Serialize; } impl DeserializePrimitiveArray for &Arc { - fn to_primitive( - &self, - ) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + fn to_primitive(&self) -> Vec<::Native> { + let primitive_array = self.as_primitive::(); + + (0..primitive_array.len()) + .map(|i| primitive_array.value(i)) + .collect() } - fn to_primitive_value(&self) -> Result, VectorStoreError> + fn to_primitive_value(&self) -> Result, serde_json::Error> where ::Native: Serialize, { self.to_primitive::() - .map_err(arrow_to_rig_error)? .iter() .map(serde_json::to_value) - .collect::, _>>() - .map_err(serde_to_rig_error) + .collect() } } -/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. +/// Trait used to "deserialize" an arrow_array::Array as as list of str objects. pub trait DeserializeByteArray { - fn to_str(&self) -> Result::Native>, ArrowError>; + /// Downcast arrow Array into a `GenericByteArray` with items that implement trait `ByteArrayType`. + /// Return the generic byte array values. + fn to_str(&self) -> Vec<&::Native>; - fn to_str_value(&self) -> Result, VectorStoreError> + /// Same as above but convert the resulting array values into serde_json::Value. + fn to_str_value(&self) -> Result, serde_json::Error> where ::Native: Serialize; } impl DeserializeByteArray for &Arc { - fn to_str(&self) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + fn to_str(&self) -> Vec<&::Native> { + let byte_array = self.as_bytes::(); + (0..byte_array.len()).map(|j| byte_array.value(j)).collect() } - fn to_str_value(&self) -> Result, VectorStoreError> + fn to_str_value(&self) -> Result, serde_json::Error> where ::Native: Serialize, { self.to_str::() - .map_err(arrow_to_rig_error)? .iter() .map(serde_json::to_value) - .collect::, _>>() - .map_err(serde_to_rig_error) + .collect() } } -/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. +/// Trait used to "deserialize" an arrow_array::Array as a list of list objects. trait DeserializeListArray { - fn nested_lists( - &self, - ) -> Result>, ArrowError>; + /// Downcast arrow Array into a `GenericListArray` with items that implement trait `OffsetSizeTrait`. + /// Return the generic list array values. + fn to_list(&self) -> Vec>; } impl DeserializeListArray for &Arc { - fn nested_lists( + fn to_list(&self) -> Vec> { + (0..self.as_list::().len()) + .map(|j| self.as_list::().value(j)) + .collect() + } +} + +/// Trait used to "deserialize" an arrow_array::Array as a list of dict objects. +trait DeserializeDictArray { + /// Downcast arrow Array into a `DictionaryArray` with items that implement trait `ArrowDictionaryKeyType`. + /// Return the dictionary keys and values as a tuple. + fn to_dict( &self, - ) -> Result>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + ) -> ( + Vec<::Native>, + &Arc, + ); + + fn to_dict_values( + &self, + ) -> Result<(Vec, &Arc), serde_json::Error> + where + ::Native: Serialize; +} + +impl DeserializeDictArray for &Arc { + fn to_dict( + &self, + ) -> ( + Vec<::Native>, + &Arc, + ) { + let dict_array = self.as_dictionary::(); + ( + (0..dict_array.keys().len()) + .map(|i| dict_array.keys().value(i)) + .collect(), + dict_array.values(), + ) + } + + fn to_dict_values( + &self, + ) -> Result<(Vec, &Arc), serde_json::Error> + where + ::Native: Serialize, + { + let (k, v) = self.to_dict::(); + + Ok(( + k.iter() + .map(serde_json::to_string) + .collect::, _>>()?, + v, + )) } } -/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. +/// Trait used to "deserialize" an arrow_array::Array as as list of fixed size list objects. trait DeserializeArray { - fn fixed_nested_lists(&self) -> Result>, ArrowError>; + /// Downcast arrow Array into a `FixedSizeListArray`. + /// Return the fixed size list array values. + fn to_fixed_lists(&self) -> Vec>; } impl DeserializeArray for &Arc { - fn fixed_nested_lists(&self) -> Result>, ArrowError> { - match self.as_any().downcast_ref::() { - Some(list_array) => Ok((0..list_array.len()) - .map(|j| list_array.value(j as usize)) - .collect::>()), - None => { - return Err(ArrowError::CastError(format!( - "Can't cast column {self:?} to fixed size list array" - ))); - } + fn to_fixed_lists(&self) -> Vec> { + let list_array = self.as_fixed_size_list(); + + (0..list_array.len()).map(|i| list_array.value(i)).collect() + } +} + +type RunArrayParts = ( + Vec<::Native>, + Arc, +); + +/// Trait used to "deserialize" an arrow_array::Array as a list of list objects. +trait DeserializeRunArray { + /// Downcast arrow Array into a `GenericListArray` with items that implement trait `RunEndIndexType`. + /// Return the generic list array values. + fn to_run_end(&self) -> Result, ArrowError>; +} + +impl DeserializeRunArray for &Arc { + fn to_run_end(&self) -> Result, ArrowError> { + if let Some(run_array) = self.as_any().downcast_ref::>() { + return Ok(( + run_array.run_ends().values().to_vec(), + run_array.values().clone(), + )); } + Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to list array" + ))) } } trait DeserializeStructArray { - fn nested_lists(&self) -> Vec>; + fn inner_lists(&self) -> Vec>; + + fn num_rows(&self) -> usize; } impl DeserializeStructArray for StructArray { - fn nested_lists(&self) -> Vec> { - (0..self.len()) + fn inner_lists(&self) -> Vec> { + (0..self.num_columns()) .map(|j| self.column(j).clone()) .collect::>() } + + fn num_rows(&self) -> usize { + self.column(0).into_data().len() + } } trait MapOk { @@ -354,16 +532,143 @@ where } } +trait RebuildObject { + fn build_struct(&self, num_rows: usize, col_names: Vec<&str>) -> Vec; + + fn build_map(&self) -> Vec; +} + +impl RebuildObject for Vec> { + fn build_struct(&self, num_rows: usize, col_names: Vec<&str>) -> Vec { + (0..num_rows) + .map(|row_i| { + self.iter() + .enumerate() + .fold(serde_json::Map::new(), |mut acc, (col_i, col)| { + acc.insert(col_names[col_i].to_string(), col[row_i].clone()); + acc + }) + }) + .map(Value::Object) + .collect() + } + + fn build_map(&self) -> Vec { + let keys = &self[0]; + let values = &self[1]; + + keys.iter() + .zip(values) + .map(|(k, v)| { + let mut map = serde_json::Map::new(); + map.insert( + match k { + serde_json::Value::String(s) => s.clone(), + _ => k.to_string(), + }, + v.clone(), + ); + map + }) + .map(Value::Object) + .collect() + } +} + #[cfg(test)] mod tests { use std::sync::Arc; use arrow_array::{ - builder::{FixedSizeListBuilder, ListBuilder, StringBuilder, StructBuilder}, ArrayRef, BinaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, ListArray, RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array + builder::{ + FixedSizeListBuilder, ListBuilder, StringBuilder, StringDictionaryBuilder, + StringRunBuilder, UnionBuilder, + }, + types::{Float64Type, Int16Type, Int32Type, Int8Type}, + ArrayRef, BinaryArray, FixedSizeListArray, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + MapArray, RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, }; - use lancedb::arrow::arrow_schema::{DataType, Field}; + use lancedb::arrow::arrow_schema::{DataType, Field, Fields}; + use serde_json::json; + + use crate::utils::deserializer::RecordBatchDeserializer; + + fn fixed_list_actors() -> FixedSizeListArray { + let mut builder = FixedSizeListBuilder::new(StringBuilder::new(), 2); + builder.values().append_value("Johnny Depp"); + builder.values().append_value("Cate Blanchet"); + builder.append(true); + builder.values().append_value("Meryl Streep"); + builder.values().append_value("Scarlett Johansson"); + builder.append(true); + builder.values().append_value("Brad Pitt"); + builder.values().append_value("Natalie Portman"); + builder.append(true); + + builder.finish() + } + + fn name_list() -> GenericListArray { + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("Alice"); + builder.values().append_value("Bob"); + builder.append(true); + builder.values().append_value("Charlie"); + builder.append(true); + builder.values().append_value("David"); + builder.values().append_value("Eve"); + builder.values().append_value("Frank"); + builder.append(true); + builder.finish() + } + + fn nested_list_of_animals() -> GenericListArray { + // [ [ [ "Dog", "Cat" ], ["Mouse"] ], [ [ "Giraffe" ], ["Cow", "Pig"] ], [ [ "Sloth" ], ["Ant", "Monkey"] ] ] + let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new())); + builder + .values() + .append_value(vec![Some("Dog"), Some("Cat")]); + builder.values().append_value(vec![Some("Mouse")]); + builder.append(true); + builder.values().append_value(vec![Some("Giraffe")]); + builder + .values() + .append_value(vec![Some("Cow"), Some("Pig")]); + builder.append(true); + builder.values().append_value(vec![Some("Sloth")]); + builder + .values() + .append_value(vec![Some("Ant"), Some("Monkey")]); + builder.append(true); + builder.finish() + } - use crate::utils::deserializer::Test; + fn movie_struct() -> StructArray { + StructArray::from(vec![ + ( + Arc::new(Field::new("name", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + "Pulp Fiction", + "The Shawshank Redemption", + "La La Land", + ])) as ArrayRef, + ), + ( + Arc::new(Field::new("year", DataType::UInt32, false)), + Arc::new(UInt32Array::from(vec![1999, 2026, 1745])) as ArrayRef, + ), + ( + Arc::new(Field::new( + "actors", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Utf8, true)), 2), + false, + )), + Arc::new(fixed_list_actors()) as ArrayRef, + ), + ]) + } #[tokio::test] async fn test_primitive_deserialization() { @@ -371,10 +676,8 @@ mod tests { let large_string = Arc::new(LargeStringArray::from_iter_values(vec!["Jerry", "Freddy"])) as ArrayRef; let binary = Arc::new(BinaryArray::from_iter_values(vec![b"hello", b"world"])) as ArrayRef; - let large_binary = Arc::new(LargeBinaryArray::from_iter_values(vec![ - b"The bright sun sets behind the mountains, casting gold", - b"A gentle breeze rustles through the trees at twilight.", - ])) as ArrayRef; + let large_binary = + Arc::new(LargeBinaryArray::from_iter_values(vec![b"abc", b"def"])) as ArrayRef; let float_32 = Arc::new(Float32Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; let float_64 = Arc::new(Float64Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; let int_8 = Arc::new(Int8Array::from_iter_values(vec![0, -1])) as ArrayRef; @@ -404,80 +707,161 @@ mod tests { ]) .unwrap(); - let _t = record_batch.deserialize().unwrap(); - - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [0.0, 1.0], + [0.0, 1.0], + [0, -1], + [0, 1], + [0, -1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + ["Marty", "Tony"], + ["Jerry", "Freddy"], + [[97, 98, 99], [100, 101, 102]], + [[104, 101, 108, 108, 111], [119, 111, 114, 108, 100]] + ]) + ) } #[tokio::test] - async fn test_list_recursion() { - let mut builder = FixedSizeListBuilder::new(StringBuilder::new(), 3); - builder.values().append_value("Hi"); - builder.values().append_value("Hey"); - builder.values().append_value("What's up"); - builder.append(true); - builder.values().append_value("Bye"); - builder.values().append_value("Seeya"); - builder.values().append_value("Later"); - builder.append(true); + async fn test_dictionary_deserialization() { + let dictionary_values = StringArray::from(vec![None, Some("abc"), Some("def")]); - let record_batch = RecordBatch::try_from_iter(vec![( - "salutations", - Arc::new(builder.finish()) as ArrayRef, - )]) - .unwrap(); + let mut builder = + StringDictionaryBuilder::::new_with_dictionary(3, &dictionary_values) + .unwrap(); + builder.append("def").unwrap(); + builder.append_null(); + builder.append("abc").unwrap(); - let _t = record_batch.deserialize().unwrap(); + let dictionary_array = builder.finish(); + + let record_batch = + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(dictionary_array) as ArrayRef)]) + .unwrap(); - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "2": "" + }, + { + "0": "abc" + }, + { + "1": "def" + } + ] + ]) + ) } #[tokio::test] - async fn test_list_recursion_2() { - let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new())); - builder - .values() - .append_value(vec![Some("Dog"), Some("Cat")]); - builder - .values() - .append_value(vec![Some("Mouse"), Some("Bird")]); - builder.append(true); - builder - .values() - .append_value(vec![Some("Giraffe"), Some("Mammoth")]); - builder - .values() - .append_value(vec![Some("Cow"), Some("Pig")]); + async fn test_union_deserialization() { + let mut builder = UnionBuilder::new_dense(); + builder.append::("type_a", 1).unwrap(); + builder.append::("type_b", 3.0).unwrap(); + builder.append::("type_a", 4).unwrap(); + let union = builder.build().unwrap(); let record_batch = - RecordBatch::try_from_iter(vec![("animals", Arc::new(builder.finish()) as ArrayRef)]) - .unwrap(); - - let _t = record_batch.deserialize().unwrap(); + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(union) as ArrayRef)]).unwrap(); - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([[[1], [3.0], [4]]]) + ) } #[tokio::test] - async fn test_struct() { - let id_values = StringArray::from(vec!["id1", "id2", "id3"]); + async fn test_run_end_deserialization() { + let mut builder = StringRunBuilder::::new(); - let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]); + // The builder builds the dictionary value by value + builder.append_value("abc"); + builder.append_null(); + builder.extend([Some("def"), Some("def"), Some("abc")]); + let array = builder.finish(); - let mut names_builder = ListBuilder::new(StringBuilder::new()); - names_builder.values().append_value("Alice"); - names_builder.values().append_value("Bob"); - names_builder.append(true); - names_builder.values().append_value("Charlie"); - names_builder.append(true); - names_builder.values().append_value("David"); - names_builder.values().append_value("Eve"); - names_builder.values().append_value("Frank"); - names_builder.append(true); + let record_batch = + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(array) as ArrayRef)]).unwrap(); + + assert_eq!( + record_batch.deserialize().unwrap(), + json!([[ + ["abc"], + ["", ""], + ["def", "def", "def", "def"], + ["abc", "abc", "abc", "abc", "abc"] + ]]) + ) + } + + #[tokio::test] + async fn test_map_deserialization() { + let record_batch = RecordBatch::try_from_iter(vec![( + "map_col", + Arc::new( + MapArray::new_from_strings( + vec!["tarentino", "darabont", "chazelle"].into_iter(), + &movie_struct(), + &[0, 1, 2], + ) + .unwrap(), + ) as ArrayRef, + )]) + .unwrap(); - let names_array = names_builder.finish(); + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "tarentino": { + "name": "Pulp Fiction", + "year": 1999, + "actors": [ + "Johnny Depp", + "Cate Blanchet" + ] + } + }, + { + "darabont": { + "name": "The Shawshank Redemption", + "year": 2026, + "actors": [ + "Meryl Streep", + "Scarlett Johansson" + ] + } + }, + { + "chazelle": { + "name": "La La Land", + "year": 1745, + "actors": [ + "Brad Pitt", + "Natalie Portman" + ] + } + } + ] + ]) + ) + } - // Step 4: Combine into a StructArray + #[tokio::test] + async fn test_recursion() { + let id_values = StringArray::from(vec!["id1", "id2", "id3"]); + let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]); let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("id", DataType::Utf8, false)), @@ -493,7 +877,38 @@ mod tests { DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), false, )), - Arc::new(names_array) as ArrayRef, + Arc::new(name_list()) as ArrayRef, + ), + ( + Arc::new(Field::new( + "favorite_animals", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ))), + false, + )), + Arc::new(nested_list_of_animals()) as ArrayRef, + ), + ( + Arc::new(Field::new( + "favorite_movie", + DataType::Struct(Fields::from_iter(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("year", DataType::UInt32, false), + Field::new( + "actors", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Utf8, true)), + 2, + ), + false, + ), + ])), + false, + )), + Arc::new(movie_struct()) as ArrayRef, ), ]); @@ -501,8 +916,87 @@ mod tests { RecordBatch::try_from_iter(vec![("employees", Arc::new(struct_array) as ArrayRef)]) .unwrap(); - let _t = record_batch.deserialize().unwrap(); - - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "id": "id1", + "age": 25.0, + "names": [ + "Alice", + "Bob" + ], + "favorite_animals": [ + [ + "Dog", + "Cat" + ], + [ + "Mouse" + ] + ], + "favorite_movie": { + "name": "Pulp Fiction", + "year": 1999, + "actors": [ + "Johnny Depp", + "Cate Blanchet" + ] + } + }, + { + "id": "id2", + "age": 30.5, + "names": [ + "Charlie" + ], + "favorite_animals": [ + [ + "Giraffe" + ], + [ + "Cow", + "Pig" + ] + ], + "favorite_movie": { + "name": "The Shawshank Redemption", + "year": 2026, + "actors": [ + "Meryl Streep", + "Scarlett Johansson" + ] + } + }, + { + "id": "id3", + "age": 22.100000381469727, + "names": [ + "David", + "Eve", + "Frank" + ], + "favorite_animals": [ + [ + "Sloth" + ], + [ + "Ant", + "Monkey" + ] + ], + "favorite_movie": { + "name": "La La Land", + "year": 1745, + "actors": [ + "Brad Pitt", + "Natalie Portman" + ] + } + } + ] + ]) + ) } }