From 6fede36516600378e8adbb4551648df627f3df78 Mon Sep 17 00:00:00 2001
From: Garance
Date: Wed, 25 Sep 2024 17:58:42 -0400
Subject: [PATCH] feat: implement deserialization for any recordbatch returned from lanceDB

---
 rig-lancedb/src/utils/deserializer.rs | 508 ++++++++++++++++++++++++++
 rig-lancedb/src/utils/mod.rs          |   1 +
 2 files changed, 509 insertions(+)
 create mode 100644 rig-lancedb/src/utils/deserializer.rs

diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs
new file mode 100644
index 00000000..bbe915b5
--- /dev/null
+++ b/rig-lancedb/src/utils/deserializer.rs
@@ -0,0 +1,508 @@
+use std::sync::Arc;
+
+use arrow_array::{
+    types::{
+        BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+        DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
+        DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
+        Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
+        LargeBinaryType, LargeUtf8Type, Time32MillisecondType, Time32SecondType,
+        Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
+        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
+        UInt32Type, UInt64Type, UInt8Type, Utf8Type,
+    },
+    Array, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray,
+    GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray,
+};
+use lancedb::arrow::arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
+use rig::vector_store::VectorStoreError;
+use serde::Serialize;
+use serde_json::{json, Value};
+
+use crate::serde_to_rig_error;
+
+/// Wraps an [`ArrowError`] in [`VectorStoreError::DatastoreError`].
+fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError {
+    VectorStoreError::DatastoreError(Box::new(e))
+}
+
+/// Converts a value into a `serde_json::Value`.
+trait Test {
+    /// Returns the JSON representation of `self`, or a [`VectorStoreError`] when a column
+    /// cannot be downcast or serialized.
+    fn deserialize(&self) -> Result<Value, VectorStoreError>;
+}
+
+impl Test for RecordBatch {
+    /// Builds a JSON object with one entry per column of the batch: the key is the column
+    /// name taken from the schema and the value is a JSON array with one element per row.
+    ///
+    /// # Errors
+    ///
+    /// Returns `VectorStoreError` when a column cannot be downcast to the array type implied
+    /// by its `DataType`, or when a value cannot be represented as JSON.
+    fn deserialize(&self) -> Result<Value, VectorStoreError> {
+        /// Converts a single arrow column into a list of JSON values, recursing into list
+        /// and struct columns.
+        fn type_matcher(column: &Arc<dyn Array>) -> Result<Vec<Value>, VectorStoreError> {
+            match column.data_type() {
+                // One JSON null per row so the result lines up with the other columns.
+                DataType::Null => Ok(vec![Value::Null; column.len()]),
+                // f16 cannot be serialized by serde_json directly, so convert to f32 first.
+                DataType::Float16 => column
+                    .to_primitive::<Float16Type>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(|float_16| serde_json::to_value(float_16.to_f32()))
+                    .collect::<Result<Vec<_>, _>>()
+                    .map_err(serde_to_rig_error),
+                DataType::Float32 => column.to_primitive_value::<Float32Type>(),
+                DataType::Float64 => column.to_primitive_value::<Float64Type>(),
+                DataType::Int8 => column.to_primitive_value::<Int8Type>(),
+                DataType::Int16 => column.to_primitive_value::<Int16Type>(),
+                DataType::Int32 => column.to_primitive_value::<Int32Type>(),
+                DataType::Int64 => column.to_primitive_value::<Int64Type>(),
+                DataType::UInt8 => column.to_primitive_value::<UInt8Type>(),
+                DataType::UInt16 => column.to_primitive_value::<UInt16Type>(),
+                DataType::UInt32 => column.to_primitive_value::<UInt32Type>(),
+                DataType::UInt64 => column.to_primitive_value::<UInt64Type>(),
+                DataType::Date32 => column.to_primitive_value::<Date32Type>(),
+                DataType::Date64 => column.to_primitive_value::<Date64Type>(),
+                DataType::Decimal128(..) => column.to_primitive_value::<Decimal128Type>(),
+                // i256 cannot be serialized by serde_json directly, so convert to i128 first.
+                // NOTE(review): values outside the i128 range presumably lose information
+                // here; confirm against the arrow `i256::as_i128` documentation.
+                DataType::Decimal256(..) => column
+                    .to_primitive::<Decimal256Type>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(|dec_256| serde_json::to_value(dec_256.as_i128()))
+                    .collect::<Result<Vec<_>, _>>()
+                    .map_err(serde_to_rig_error),
+                DataType::Time32(TimeUnit::Second) => {
+                    column.to_primitive_value::<Time32SecondType>()
+                }
+                DataType::Time32(TimeUnit::Millisecond) => {
+                    column.to_primitive_value::<Time32MillisecondType>()
+                }
+                DataType::Time64(TimeUnit::Microsecond) => {
+                    column.to_primitive_value::<Time64MicrosecondType>()
+                }
+                DataType::Time64(TimeUnit::Nanosecond) => {
+                    column.to_primitive_value::<Time64NanosecondType>()
+                }
+                DataType::Timestamp(TimeUnit::Microsecond, ..) => {
+                    column.to_primitive_value::<TimestampMicrosecondType>()
+                }
+                DataType::Timestamp(TimeUnit::Millisecond, ..) => {
+                    column.to_primitive_value::<TimestampMillisecondType>()
+                }
+                DataType::Timestamp(TimeUnit::Second, ..) => {
+                    column.to_primitive_value::<TimestampSecondType>()
+                }
+                DataType::Timestamp(TimeUnit::Nanosecond, ..) => {
+                    column.to_primitive_value::<TimestampNanosecondType>()
+                }
+                DataType::Duration(TimeUnit::Microsecond) => {
+                    column.to_primitive_value::<DurationMicrosecondType>()
+                }
+                DataType::Duration(TimeUnit::Millisecond) => {
+                    column.to_primitive_value::<DurationMillisecondType>()
+                }
+                DataType::Duration(TimeUnit::Nanosecond) => {
+                    column.to_primitive_value::<DurationNanosecondType>()
+                }
+                DataType::Duration(TimeUnit::Second) => {
+                    column.to_primitive_value::<DurationSecondType>()
+                }
+                // Multi-field intervals are emitted as JSON objects, one key per field.
+                DataType::Interval(IntervalUnit::DayTime) => Ok(column
+                    .to_primitive::<IntervalDayTimeType>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(|interval| {
+                        json!({
+                            "days": interval.days,
+                            "milliseconds": interval.milliseconds,
+                        })
+                    })
+                    .collect()),
+                DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column
+                    .to_primitive::<IntervalMonthDayNanoType>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(|interval| {
+                        json!({
+                            "months": interval.months,
+                            "days": interval.days,
+                            "nanoseconds": interval.nanoseconds,
+                        })
+                    })
+                    .collect()),
+                DataType::Interval(IntervalUnit::YearMonth) => {
+                    column.to_primitive_value::<IntervalYearMonthType>()
+                }
+                DataType::Utf8 | DataType::Utf8View => column.to_str_value::<Utf8Type>(),
+                DataType::LargeUtf8 => column.to_str_value::<LargeUtf8Type>(),
+                DataType::Binary => column.to_str_value::<BinaryType>(),
+                DataType::LargeBinary => column.to_str_value::<LargeBinaryType>(),
+                DataType::FixedSizeBinary(_) => {
+                    match column.as_any().downcast_ref::<FixedSizeBinaryArray>() {
+                        // Walk the rows of the array. The size carried by the data type is
+                        // the byte width of each value, not the number of rows.
+                        Some(binary_array) => (0..binary_array.len())
+                            .map(|j| serde_json::to_value(binary_array.value(j)))
+                            .collect::<Result<Vec<_>, _>>()
+                            .map_err(serde_to_rig_error),
+                        None => Err(VectorStoreError::DatastoreError(Box::new(
+                            ArrowError::CastError(format!(
+                                "Can't cast column {column:?} to fixed size binary array"
+                            )),
+                        ))),
+                    }
+                }
+                // List-like columns: every row is itself an array, converted recursively
+                // and emitted as one JSON array per row.
+                DataType::FixedSizeList(..) => column
+                    .fixed_nested_lists()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(type_matcher)
+                    .map_ok(),
+                DataType::List(..) | DataType::ListView(..) => column
+                    .nested_lists::<i32>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(type_matcher)
+                    .map_ok(),
+                DataType::LargeList(..) | DataType::LargeListView(..) => column
+                    .nested_lists::<i64>()
+                    .map_err(arrow_to_rig_error)?
+                    .iter()
+                    .map(type_matcher)
+                    .map_ok(),
+                // Struct columns: each array returned by `nested_lists` is converted
+                // recursively and emitted as one JSON array.
+                DataType::Struct(..) => match column.as_any().downcast_ref::<StructArray>() {
+                    Some(struct_array) => struct_array
+                        .nested_lists()
+                        .iter()
+                        .map(type_matcher)
+                        .map_ok(),
+                    None => Err(VectorStoreError::DatastoreError(Box::new(
+                        ArrowError::CastError(format!(
+                            "Can't cast array: {column:?} to struct array"
+                        )),
+                    ))),
+                },
+                // TODO: `DataType::Map` and `DataType::Dictionary` are not handled yet and
+                // fall through to the arm below.
+                _ => {
+                    println!("Unsupported data type");
+                    // Best effort: one JSON null per row keeps the column aligned.
+                    Ok(vec![Value::Null; column.len()])
+                }
+            }
+        }
+
+        // Pair every column with its schema field and key the converted values by name.
+        let schema = self.schema();
+        let mut record = serde_json::Map::new();
+        for (field, column) in schema.fields().iter().zip(self.columns()) {
+            record.insert(field.name().clone(), Value::Array(type_matcher(column)?));
+        }
+
+        Ok(Value::Object(record))
+    }
+}
+
+/// Trait used to "deserialize" an arrow_array::Array as a list of primitive objects.
+pub trait DeserializePrimitiveArray {
+    /// Downcasts the array to `PrimitiveArray<T>` and returns its values in row order.
+    ///
+    /// Fails with `ArrowError::CastError` when the array is not a `PrimitiveArray<T>`.
+    fn to_primitive<T: ArrowPrimitiveType>(
+        &self,
+    ) -> Result<Vec<<T as ArrowPrimitiveType>::Native>, ArrowError>;
+
+    /// Same as [`Self::to_primitive`], with every value converted to a `serde_json::Value`.
+    fn to_primitive_value<T: ArrowPrimitiveType>(&self) -> Result<Vec<Value>, VectorStoreError>
+    where
+        <T as ArrowPrimitiveType>::Native: Serialize;
+}
+
+impl DeserializePrimitiveArray for &Arc<dyn Array> {
+    fn to_primitive<T: ArrowPrimitiveType>(
+        &self,
+    ) -> Result<Vec<<T as ArrowPrimitiveType>::Native>, ArrowError> {
+        let array = self
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!("Can't cast array: {self:?} to primitive array"))
+            })?;
+
+        Ok((0..array.len()).map(|j| array.value(j)).collect())
+    }
+
+    fn to_primitive_value<T: ArrowPrimitiveType>(&self) -> Result<Vec<Value>, VectorStoreError>
+    where
+        <T as ArrowPrimitiveType>::Native: Serialize,
+    {
+        self.to_primitive::<T>()
+            .map_err(arrow_to_rig_error)?
+            .iter()
+            .map(serde_json::to_value)
+            .collect::<Result<Vec<_>, _>>()
+            .map_err(serde_to_rig_error)
+    }
+}
+
+/// Trait used to "deserialize" an arrow_array::Array as a list of byte objects.
+pub trait DeserializeByteArray {
+    /// Downcasts the array to `GenericByteArray<T>` and returns a borrowed value per row.
+    ///
+    /// Fails with `ArrowError::CastError` when the array is not a `GenericByteArray<T>`.
+    fn to_str<T: ByteArrayType>(&self) -> Result<Vec<&<T as ByteArrayType>::Native>, ArrowError>;
+
+    /// Same as [`Self::to_str`], with every value converted to a `serde_json::Value`.
+    fn to_str_value<T: ByteArrayType>(&self) -> Result<Vec<Value>, VectorStoreError>
+    where
+        <T as ByteArrayType>::Native: Serialize;
+}
+
+impl DeserializeByteArray for &Arc<dyn Array> {
+    fn to_str<T: ByteArrayType>(&self) -> Result<Vec<&<T as ByteArrayType>::Native>, ArrowError> {
+        let array = self
+            .as_any()
+            .downcast_ref::<GenericByteArray<T>>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!("Can't cast array: {self:?} to byte array"))
+            })?;
+
+        Ok((0..array.len()).map(|j| array.value(j)).collect())
+    }
+
+    fn to_str_value<T: ByteArrayType>(&self) -> Result<Vec<Value>, VectorStoreError>
+    where
+        <T as ByteArrayType>::Native: Serialize,
+    {
+        self.to_str::<T>()
+            .map_err(arrow_to_rig_error)?
+            .iter()
+            .map(serde_json::to_value)
+            .collect::<Result<Vec<_>, _>>()
+            .map_err(serde_to_rig_error)
+    }
+}
+
+/// Trait used to "deserialize" an arrow_array::Array as a list of list objects.
+trait DeserializeListArray {
+    /// Downcasts the array to `GenericListArray<T>` and returns the child array of each row.
+    ///
+    /// Fails with `ArrowError::CastError` when the array is not a `GenericListArray<T>`.
+    fn nested_lists<T: OffsetSizeTrait>(&self) -> Result<Vec<Arc<dyn Array>>, ArrowError>;
+}
+
+impl DeserializeListArray for &Arc<dyn Array> {
+    fn nested_lists<T: OffsetSizeTrait>(&self) -> Result<Vec<Arc<dyn Array>>, ArrowError> {
+        let array = self
+            .as_any()
+            .downcast_ref::<GenericListArray<T>>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!("Can't cast array: {self:?} to list array"))
+            })?;
+
+        Ok((0..array.len()).map(|j| array.value(j)).collect())
+    }
+}
+
+/// Trait used to "deserialize" an arrow_array::Array as a list of fixed size list objects.
+trait DeserializeArray {
+    /// Downcasts the array to `FixedSizeListArray` and returns the child array of each row.
+    ///
+    /// Fails with `ArrowError::CastError` when the array is not a `FixedSizeListArray`.
+    fn fixed_nested_lists(&self) -> Result<Vec<Arc<dyn Array>>, ArrowError>;
+}
+
+impl DeserializeArray for &Arc<dyn Array> {
+    fn fixed_nested_lists(&self) -> Result<Vec<Arc<dyn Array>>, ArrowError> {
+        let list_array = self
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!(
+                    "Can't cast column {self:?} to fixed size list array"
+                ))
+            })?;
+
+        Ok((0..list_array.len()).map(|j| list_array.value(j)).collect())
+    }
+}
+
+/// Trait used to split an arrow_array::StructArray into its child columns.
+trait DeserializeStructArray {
+    /// Returns one array per field of the struct, in field order.
+    fn nested_lists(&self) -> Vec<Arc<dyn Array>>;
+}
+
+impl DeserializeStructArray for StructArray {
+    fn nested_lists(&self) -> Vec<Arc<dyn Array>> {
+        // Iterate over the child columns. `len()` is the number of rows, so using it as a
+        // column index panics or drops columns whenever rows and fields differ in count.
+        self.columns().to_vec()
+    }
+}
+
+/// Adapter for iterators of per-row conversion results.
+trait MapOk {
+    /// Turns every successful `Vec<Value>` into a single JSON array and collects them,
+    /// stopping at the first error.
+    fn map_ok(self) -> Result<Vec<Value>, VectorStoreError>;
+}
+
+impl<I> MapOk for I
+where
+    I: Iterator<Item = Result<Vec<Value>, VectorStoreError>>,
+{
+    fn map_ok(self) -> Result<Vec<Value>, VectorStoreError> {
+        self.map(|maybe_list| maybe_list.map(Value::Array))
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow_array::{
+        builder::{FixedSizeListBuilder, ListBuilder, StringBuilder, StructBuilder},
+        ArrayRef, BinaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array,
+        Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, ListArray,
+        RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+    };
+    use lancedb::arrow::arrow_schema::{DataType, Field};
+
+    use crate::utils::deserializer::Test;
+
+    #[test]
+    fn test_primitive_deserialization() {
+        let string = Arc::new(StringArray::from_iter_values(vec!["Marty", "Tony"])) as ArrayRef;
+        let large_string =
+            Arc::new(LargeStringArray::from_iter_values(vec!["Jerry", "Freddy"])) as ArrayRef;
+        let binary = Arc::new(BinaryArray::from_iter_values(vec![b"hello", b"world"])) as ArrayRef;
+        let large_binary = Arc::new(LargeBinaryArray::from_iter_values(vec![
+            b"The bright sun sets behind the mountains, casting gold",
+            b"A gentle breeze rustles through the trees at twilight.",
+        ])) as ArrayRef;
+        let float_32 = Arc::new(Float32Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef;
+        let float_64 = Arc::new(Float64Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef;
+        let int_8 = Arc::new(Int8Array::from_iter_values(vec![0, -1])) as ArrayRef;
+        let int_16 = Arc::new(Int16Array::from_iter_values(vec![-0, 1])) as ArrayRef;
+        let int_32 = Arc::new(Int32Array::from_iter_values(vec![0, -1])) as ArrayRef;
+        let int_64 = Arc::new(Int64Array::from_iter_values(vec![-0, 1])) as ArrayRef;
+        let uint_8 = Arc::new(UInt8Array::from_iter_values(vec![0, 1])) as ArrayRef;
+        let uint_16 = Arc::new(UInt16Array::from_iter_values(vec![0, 1])) as ArrayRef;
+        let uint_32 = Arc::new(UInt32Array::from_iter_values(vec![0, 1])) as ArrayRef;
+        let uint_64 = Arc::new(UInt64Array::from_iter_values(vec![0, 1])) as ArrayRef;
+
+        let record_batch = RecordBatch::try_from_iter(vec![
+            ("float_32", float_32),
+            ("float_64", float_64),
+            ("int_8", int_8),
+            ("int_16", int_16),
+            ("int_32", int_32),
+            ("int_64", int_64),
+            ("uint_8", uint_8),
+            ("uint_16", uint_16),
+            ("uint_32", uint_32),
+            ("uint_64", uint_64),
+            ("string", string),
+            ("large_string", large_string),
+            ("large_binary", large_binary),
+            ("binary", binary),
+        ])
+        .unwrap();
+
+        // Every supported primitive and byte column must convert without error.
+        let deserialized = record_batch.deserialize().unwrap();
+
+        assert!(deserialized.is_object())
+    }
+
+    #[test]
+    fn test_list_recursion() {
+        let mut builder = FixedSizeListBuilder::new(StringBuilder::new(), 3);
+        builder.values().append_value("Hi");
+        builder.values().append_value("Hey");
+        builder.values().append_value("What's up");
+        builder.append(true);
+        builder.values().append_value("Bye");
+        builder.values().append_value("Seeya");
+        builder.values().append_value("Later");
+        builder.append(true);
+
+        let record_batch = RecordBatch::try_from_iter(vec![(
+            "salutations",
+            Arc::new(builder.finish()) as ArrayRef,
+        )])
+        .unwrap();
+
+        let deserialized = record_batch.deserialize().unwrap();
+
+        assert!(deserialized.is_object())
+    }
+
+    #[test]
+    fn test_list_recursion_2() {
+        let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new()));
+        builder
+            .values()
+            .append_value(vec![Some("Dog"), Some("Cat")]);
+        builder
+            .values()
+            .append_value(vec![Some("Mouse"), Some("Bird")]);
+        builder.append(true);
+        builder
+            .values()
+            .append_value(vec![Some("Giraffe"), Some("Mammoth")]);
+        builder
+            .values()
+            .append_value(vec![Some("Cow"), Some("Pig")]);
+        // Close the second outer list so its two inner lists become a row of the array.
+        builder.append(true);
+
+        let record_batch =
+            RecordBatch::try_from_iter(vec![("animals", Arc::new(builder.finish()) as ArrayRef)])
+                .unwrap();
+
+        let deserialized = record_batch.deserialize().unwrap();
+
+        assert!(deserialized.is_object())
+    }
+
+    #[test]
+    fn test_struct() {
+        let id_values = StringArray::from(vec!["id1", "id2", "id3"]);
+
+        let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]);
+
+        let mut names_builder = ListBuilder::new(StringBuilder::new());
+        names_builder.values().append_value("Alice");
+        names_builder.values().append_value("Bob");
+        names_builder.append(true);
+        names_builder.values().append_value("Charlie");
+        names_builder.append(true);
+        names_builder.values().append_value("David");
+        names_builder.values().append_value("Eve");
+        names_builder.values().append_value("Frank");
+        names_builder.append(true);
+
+        let names_array = names_builder.finish();
+
+        // Combine the three child arrays into a single StructArray.
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("id", DataType::Utf8, false)),
+                Arc::new(id_values) as ArrayRef,
+            ),
+            (
+                Arc::new(Field::new("age", DataType::Float32, false)),
+                Arc::new(age_values) as ArrayRef,
+            ),
+            (
+                Arc::new(Field::new(
+                    "names",
+                    DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+                    false,
+                )),
+                Arc::new(names_array) as ArrayRef,
+            ),
+        ]);
+
+        let record_batch =
+            RecordBatch::try_from_iter(vec![("employees", Arc::new(struct_array) as ArrayRef)])
+                .unwrap();
+
+        let deserialized = record_batch.deserialize().unwrap();
+
+        assert!(deserialized.is_object())
+    }
+}
diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs
index bf8874e2..bb9d2599 100644
--- a/rig-lancedb/src/utils/mod.rs
+++ b/rig-lancedb/src/utils/mod.rs
@@ -1,3 +1,4 @@
+pub mod deserializer;
 use std::sync::Arc;
 
 use arrow_array::{