diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index 195ac7e11611..6118607bcbbf 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -895,8 +895,8 @@ mod tests { .build() .unwrap(), ); - assert_eq!(string.value_offsets(), &[]); - assert_eq!(string.value_sizes(), &[]); + assert_eq!(string.value_offsets(), &[] as &[i32; 0]); + assert_eq!(string.value_sizes(), &[] as &[i32; 0]); let string = LargeListViewArray::from( ArrayData::builder(DataType::LargeListView(f)) @@ -906,8 +906,8 @@ mod tests { .unwrap(), ); assert_eq!(string.len(), 0); - assert_eq!(string.value_offsets(), &[]); - assert_eq!(string.value_sizes(), &[]); + assert_eq!(string.value_offsets(), &[] as &[i64; 0]); + assert_eq!(string.value_sizes(), &[] as &[i64; 0]); } #[test] diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index d1bcf046b7ca..ffea42db6653 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -34,21 +34,27 @@ path = "src/lib.rs" bench = false [dependencies] -serde = { version = "1.0", default-features = false, features = ["derive", "std", "rc"], optional = true } +serde = { version = "1.0", default-features = false, features = [ + "derive", + "std", + "rc", +], optional = true } bitflags = { version = "2.0.0", default-features = false, optional = true } +serde_json = { version = "1.0", optional = true } [features] +canonical_extension_types = ["dep:serde", "dep:serde_json"] # Enable ffi support ffi = ["bitflags"] +serde = ["dep:serde"] [package.metadata.docs.rs] features = ["ffi"] [dev-dependencies] -serde_json = "1.0" bincode = { version = "1.3.3", default-features = false } criterion = { version = "0.5", default-features = false } [[bench]] name = "ffi" -harness = false \ No newline at end of file +harness = false diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs new file mode 100644 index 000000000000..3f6c50cb3e5e --- /dev/null +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 8-bit Boolean +//! +//! + +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for `8-bit Boolean`. +/// +/// Extension name: `arrow.bool8`. +/// +/// The storage type of the extension is `Int8` where: +/// - false is denoted by the value 0. +/// - true can be specified using any non-zero value. Preferably 1. +/// +/// +#[derive(Debug, Default, Clone, Copy, PartialEq)] +pub struct Bool8; + +impl ExtensionType for Bool8 { + const NAME: &'static str = "arrow.bool8"; + + type Metadata = &'static str; + + fn metadata(&self) -> &Self::Metadata { + &"" + } + + fn serialize_metadata(&self) -> Option { + Some(String::default()) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + if metadata.map_or(false, str::is_empty) { + Ok("") + } else { + Err(ArrowError::InvalidArgumentError( + "Bool8 extension type expects an empty string as metadata".to_owned(), + )) + } + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + match data_type { + DataType::Int8 => Ok(()), + data_type => Err(ArrowError::InvalidArgumentError(format!( + "Bool8 data type mismatch, expected Int8, found {data_type}" + ))), + } + } + + fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { + Self.supports_data_type(data_type).map(|_| Self) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::Int8, false); + field.try_with_extension_type(Bool8)?; + field.try_extension_type::()?; + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Bool8(Bool8) + ); + + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "expected Int8, found Boolean")] + fn invalid_type() { + Field::new("", DataType::Boolean, false).with_extension_type(Bool8); + } + + #[test] + #[should_panic(expected = "Bool8 extension type expects an empty string as metadata")] + fn missing_metadata() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_NAME_KEY.to_owned(), Bool8::NAME.to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "Bool8 extension type expects an empty string as metadata")] + fn invalid_metadata() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Bool8::NAME.to_owned()), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + "non-empty".to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs new file mode 100644 index 000000000000..6fe94fba78aa --- /dev/null +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -0,0 +1,443 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! FixedShapeTensor +//! +//! + +use serde::{Deserialize, Serialize}; + +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for fixed shape tensor. +/// +/// Extension name: `arrow.fixed_shape_tensor`. +/// +/// The storage type of the extension: `FixedSizeList` where: +/// - `value_type` is the data type of individual tensor elements. +/// - `list_size` is the product of all the elements in tensor shape. +/// +/// Extension type parameters: +/// - `value_type`: the Arrow data type of individual tensor elements. +/// - `shape`: the physical shape of the contained tensors as an array. +/// +/// Optional parameters describing the logical layout: +/// - `dim_names`: explicit names to tensor dimensions as an array. The +/// length of it should be equal to the shape length and equal to the +/// number of dimensions. +/// `dim_names` can be used if the dimensions have +/// well-known names and they map to the physical layout (row-major). +/// - `permutation`: indices of the desired ordering of the original +/// dimensions, defined as an array. +/// The indices contain a permutation of the values `[0, 1, .., N-1]` +/// where `N` is the number of dimensions. The permutation indicates +/// which dimension of the logical layout corresponds to which dimension +/// of the physical tensor (the i-th dimension of the logical view +/// corresponds to the dimension with number `permutations[i]` of the +/// physical tensor). +/// Permutation can be useful in case the logical order of the tensor is +/// a permutation of the physical order (row-major). +/// When logical and physical layout are equal, the permutation will +/// always be `([0, 1, .., N-1])` and can therefore be left out. +/// +/// Description of the serialization: +/// The metadata must be a valid JSON object including shape of the +/// contained tensors as an array with key `shape` plus optional +/// dimension names with keys `dim_names` and ordering of the +/// dimensions with key `permutation`. +/// Example: `{ "shape": [2, 5]}` +/// Example with `dim_names` metadata for NCHW ordered data: +/// `{ "shape": [100, 200, 500], "dim_names": ["C", "H", "W"]}` +/// Example of permuted 3-dimensional tensor: +/// `{ "shape": [100, 200, 500], "permutation": [2, 0, 1]}` +/// +/// This is the physical layout shape and the shape of the logical layout +/// would in this case be `[500, 100, 200]`. +/// +/// +#[derive(Debug, Clone, PartialEq)] +pub struct FixedShapeTensor { + /// The data type of individual tensor elements. + value_type: DataType, + + /// The metadata of this extension type. + metadata: FixedShapeTensorMetadata, +} + +impl FixedShapeTensor { + /// Returns a new fixed shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names or permutations are + /// invalid. + pub fn try_new( + value_type: DataType, + shape: impl IntoIterator, + dimension_names: Option>, + permutations: Option>, + ) -> Result { + // TODO: are all data types are suitable as value type? + FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| { + Self { + value_type, + metadata, + } + }) + } + + /// Returns the value type of the individual tensor elements. + pub fn value_type(&self) -> &DataType { + &self.value_type + } + + /// Returns the product of all the elements in tensor shape. + pub fn list_size(&self) -> usize { + self.metadata.list_size() + } + + /// Returns the number of dimensions in this fixed shape tensor. + pub fn dimensions(&self) -> usize { + self.metadata.dimensions() + } + + /// Returns the names of the dimensions in this fixed shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.metadata.dimension_names() + } + + /// Returns the indices of the desired ordering of the original + /// dimensions, if set. + pub fn permutations(&self) -> Option<&[usize]> { + self.metadata.permutations() + } +} + +/// Extension type metadata for [`FixedShapeTensor`]. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct FixedShapeTensorMetadata { + /// The physical shape of the contained tensors. + shape: Vec, + + /// Explicit names to tensor dimensions. + dim_names: Option>, + + /// Indices of the desired ordering of the original dimensions. + permutations: Option>, +} + +impl FixedShapeTensorMetadata { + /// Returns metadata for a fixed shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names or permutations are + /// invalid. + pub fn try_new( + shape: impl IntoIterator, + dimension_names: Option>, + permutations: Option>, + ) -> Result { + let shape = shape.into_iter().collect::>(); + let dimensions = shape.len(); + + let dim_names = dimension_names.map(|dimension_names| { + if dimension_names.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len() + ))) + } else { + Ok(dimension_names) + } + }).transpose()?; + + let permutations = permutations + .map(|permutations| { + if permutations.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}", + permutations.len() + ))) + } else { + let mut sorted_permutations = permutations.clone(); + sorted_permutations.sort_unstable(); + if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" + ))) + } else { + Ok(permutations) + } + } + }) + .transpose()?; + + Ok(Self { + shape, + dim_names, + permutations, + }) + } + + /// Returns the product of all the elements in tensor shape. + pub fn list_size(&self) -> usize { + self.shape.iter().product() + } + + /// Returns the number of dimensions in this fixed shape tensor. + pub fn dimensions(&self) -> usize { + self.shape.len() + } + + /// Returns the names of the dimensions in this fixed shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.dim_names.as_ref().map(AsRef::as_ref) + } + + /// Returns the indices of the desired ordering of the original + /// dimensions, if set. + pub fn permutations(&self) -> Option<&[usize]> { + self.permutations.as_ref().map(AsRef::as_ref) + } +} + +impl ExtensionType for FixedShapeTensor { + const NAME: &'static str = "arrow.fixed_shape_tensor"; + + type Metadata = FixedShapeTensorMetadata; + + fn metadata(&self) -> &Self::Metadata { + &self.metadata + } + + fn serialize_metadata(&self) -> Option { + Some(serde_json::to_string(&self.metadata).expect("metadata serialization")) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or_else( + || { + Err(ArrowError::InvalidArgumentError( + "FixedShapeTensor extension types requires metadata".to_owned(), + )) + }, + |value| { + serde_json::from_str(value).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor metadata deserialization failed: {e}" + )) + }) + }, + ) + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + let expected = DataType::new_fixed_size_list( + self.value_type.clone(), + i32::try_from(self.list_size()).expect("overflow"), + false, + ); + data_type + .equals_datatype(&expected) + .then_some(()) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}" + )) + }) + } + + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { + match data_type { + DataType::FixedSizeList(field, list_size) if !field.is_nullable() => { + // Make sure the metadata is valid. + let metadata = FixedShapeTensorMetadata::try_new( + metadata.shape, + metadata.dim_names, + metadata.permutations, + )?; + // Make sure it is compatible with this data type. + let expected_size = i32::try_from(metadata.list_size()).expect("overflow"); + if *list_size != expected_size { + Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)" + ))) + } else { + Ok(Self { + value_type: field.data_type().clone(), + metadata, + }) + } + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let fixed_shape_tensor = FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]), + Some(vec![2, 0, 1]), + )?; + let mut field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ); + field.try_with_extension_type(fixed_shape_tensor.clone())?; + assert_eq!( + field.try_extension_type::()?, + fixed_shape_tensor + ); + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = + Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) + .with_metadata( + [( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "shape": [100, 200, 500], }"#.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")] + fn invalid_type() { + let fixed_shape_tensor = + FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap(); + let field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ); + field.with_extension_type(fixed_shape_tensor); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor extension types requires metadata")] + fn missing_metadata() { + let field = + Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) + .with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + FixedShapeTensor::NAME.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`" + )] + fn invalid_metadata() { + let fixed_shape_tensor = + FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap(); + let field = Field::new_fixed_size_list( + "", + Field::new("", DataType::Float32, false), + i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), + false, + ) + .with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + FixedShapeTensor::NAME.to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "not-shape": [] }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2" + )] + fn invalid_metadata_dimension_names() { + FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + Some(vec!["a".to_owned(), "b".to_owned()]), + None, + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")] + fn invalid_metadata_permutations_len() { + FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0])) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3" + )] + fn invalid_metadata_permutations_values() { + FixedShapeTensor::try_new( + DataType::Float32, + [100, 200, 500], + None, + Some(vec![4, 3, 2]), + ) + .unwrap(); + } +} diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs new file mode 100644 index 000000000000..0a8a1ae7e020 --- /dev/null +++ b/arrow-schema/src/extension/canonical/json.rs @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! JSON +//! +//! + +use serde::{Deserialize, Serialize}; + +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for `JSON`. +/// +/// Extension name: `arrow.json`. +/// +/// The storage type of this extension is `String` or `LargeString` or +/// `StringView`. Only UTF-8 encoded JSON as specified in [rfc8259](https://datatracker.ietf.org/doc/html/rfc8259) +/// is supported. +/// +/// This type does not have any parameters. +/// +/// Metadata is either an empty string or a JSON string with an empty +/// object. In the future, additional fields may be added, but they are not +/// required to interpret the array. +/// +/// +#[derive(Debug, Clone, Default, PartialEq)] +pub struct Json(JsonMetadata); + +/// Empty object +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +struct Empty {} + +/// Extension type metadata for [`Json`]. +#[derive(Debug, Default, Clone, PartialEq)] +pub struct JsonMetadata(Option); + +impl ExtensionType for Json { + const NAME: &'static str = "arrow.json"; + + type Metadata = JsonMetadata; + + fn metadata(&self) -> &Self::Metadata { + &self.0 + } + + fn serialize_metadata(&self) -> Option { + Some( + self.metadata() + .0 + .as_ref() + .map(serde_json::to_string) + .map(Result::unwrap) + .unwrap_or_else(|| "".to_owned()), + ) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + const ERR: &str = "Json extension type metadata is either an empty string or a JSON string with an empty object"; + metadata + .map_or_else( + || Err(ArrowError::InvalidArgumentError(ERR.to_owned())), + |metadata| { + match metadata { + // Empty string + "" => Ok(None), + value => serde_json::from_str::(value) + .map(Option::Some) + .map_err(|_| ArrowError::InvalidArgumentError(ERR.to_owned())), + } + }, + ) + .map(JsonMetadata) + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + match data_type { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(()), + data_type => Err(ArrowError::InvalidArgumentError(format!( + "Json data type mismatch, expected one of Utf8, LargeUtf8, Utf8View, found {data_type}" + ))), + } + } + + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { + let json = Self(metadata); + json.supports_data_type(data_type)?; + Ok(json) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::Utf8, false); + field.try_with_extension_type(Json::default())?; + assert_eq!( + field.metadata().get(EXTENSION_TYPE_METADATA_KEY), + Some(&"".to_owned()) + ); + assert_eq!( + field.try_extension_type::()?, + Json(JsonMetadata(None)) + ); + + let mut field = Field::new("", DataType::LargeUtf8, false); + field.try_with_extension_type(Json(JsonMetadata(Some(Empty {}))))?; + assert_eq!( + field.metadata().get(EXTENSION_TYPE_METADATA_KEY), + Some(&"{}".to_owned()) + ); + assert_eq!( + field.try_extension_type::()?, + Json(JsonMetadata(Some(Empty {}))) + ); + + let mut field = Field::new("", DataType::Utf8View, false); + field.try_with_extension_type(Json::default())?; + field.try_extension_type::()?; + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Json(Json::default()) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Int8, false).with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "expected one of Utf8, LargeUtf8, Utf8View, found Null")] + fn invalid_type() { + Field::new("", DataType::Null, false).with_extension_type(Json::default()); + } + + #[test] + #[should_panic( + expected = "Json extension type metadata is either an empty string or a JSON string with an empty object" + )] + fn invalid_metadata() { + let field = Field::new("", DataType::Utf8, false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned()), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), "1234".to_owned()), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "Json extension type metadata is either an empty string or a JSON string with an empty object" + )] + fn missing_metadata() { + let field = Field::new("", DataType::LargeUtf8, false).with_metadata( + [(EXTENSION_TYPE_NAME_KEY.to_owned(), Json::NAME.to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/mod.rs b/arrow-schema/src/extension/canonical/mod.rs new file mode 100644 index 000000000000..3d66299ca885 --- /dev/null +++ b/arrow-schema/src/extension/canonical/mod.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Canonical extension types. +//! +//! The Arrow columnar format allows defining extension types so as to extend +//! standard Arrow data types with custom semantics. Often these semantics will +//! be specific to a system or application. However, it is beneficial to share +//! the definitions of well-known extension types so as to improve +//! interoperability between different systems integrating Arrow columnar data. +//! +//! + +mod bool8; +pub use bool8::Bool8; +mod fixed_shape_tensor; +pub use fixed_shape_tensor::{FixedShapeTensor, FixedShapeTensorMetadata}; +mod json; +pub use json::{Json, JsonMetadata}; +mod opaque; +pub use opaque::{Opaque, OpaqueMetadata}; +mod uuid; +pub use uuid::Uuid; +mod variable_shape_tensor; +pub use variable_shape_tensor::{VariableShapeTensor, VariableShapeTensorMetadata}; + +use crate::{ArrowError, Field}; + +use super::ExtensionType; + +/// Canonical extension types. +/// +/// +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq)] +pub enum CanonicalExtensionType { + /// The extension type for `FixedShapeTensor`. + /// + /// + FixedShapeTensor(FixedShapeTensor), + + /// The extension type for `VariableShapeTensor`. + /// + /// + VariableShapeTensor(VariableShapeTensor), + + /// The extension type for 'JSON'. + /// + /// + Json(Json), + + /// The extension type for `UUID`. + /// + /// + Uuid(Uuid), + + /// The extension type for `Opaque`. + /// + /// + Opaque(Opaque), + + /// The extension type for `Bool8`. + /// + /// + Bool8(Bool8), +} + +impl TryFrom<&Field> for CanonicalExtensionType { + type Error = ArrowError; + + fn try_from(value: &Field) -> Result { + // Canonical extension type names start with `arrow.` + match value.extension_type_name() { + // An extension type name with an `arrow.` prefix + Some(name) if name.starts_with("arrow.") => match name { + FixedShapeTensor::NAME => value.try_extension_type::().map(Into::into), + VariableShapeTensor::NAME => value.try_extension_type::().map(Into::into), + Json::NAME => value.try_extension_type::().map(Into::into), + Uuid::NAME => value.try_extension_type::().map(Into::into), + Opaque::NAME => value.try_extension_type::().map(Into::into), + Bool8::NAME => value.try_extension_type::().map(Into::into), + _ => Err(ArrowError::InvalidArgumentError(format!("Unsupported canonical extension type: {name}"))), + }, + // Name missing the expected prefix + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Field extension type name mismatch, expected a name with an `arrow.` prefix, found {name}" + ))), + // Name missing + None => Err(ArrowError::InvalidArgumentError("Field extension type name missing".to_owned())), + } + } +} + +impl From for CanonicalExtensionType { + fn from(value: FixedShapeTensor) -> Self { + CanonicalExtensionType::FixedShapeTensor(value) + } +} + +impl From for CanonicalExtensionType { + fn from(value: VariableShapeTensor) -> Self { + CanonicalExtensionType::VariableShapeTensor(value) + } +} + +impl From for CanonicalExtensionType { + fn from(value: Json) -> Self { + CanonicalExtensionType::Json(value) + } +} + +impl From for CanonicalExtensionType { + fn from(value: Uuid) -> Self { + CanonicalExtensionType::Uuid(value) + } +} + +impl From for CanonicalExtensionType { + fn from(value: Opaque) -> Self { + CanonicalExtensionType::Opaque(value) + } +} + +impl From for CanonicalExtensionType { + fn from(value: Bool8) -> Self { + CanonicalExtensionType::Bool8(value) + } +} diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs new file mode 100644 index 000000000000..1db7265cfde7 --- /dev/null +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Opaque +//! +//! + +use serde::{Deserialize, Serialize}; + +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for `Opaque`. +/// +/// Extension name: `arrow.opaque`. +/// +/// Opaque represents a type that an Arrow-based system received from an +/// external (often non-Arrow) system, but that it cannot interpret. In this +/// case, it can pass on Opaque to its clients to at least show that a field +/// exists and preserve metadata about the type from the other system. +/// +/// The storage type of this extension is any type. If there is no underlying +/// data, the storage type should be Null. +#[derive(Debug, Clone, PartialEq)] +pub struct Opaque(OpaqueMetadata); + +impl Opaque { + /// Returns a new `Opaque` extension type. + pub fn new(type_name: impl Into, vendor_name: impl Into) -> Self { + Self(OpaqueMetadata::new(type_name, vendor_name)) + } + + /// Returns the name of the unknown type in the external system. + pub fn type_name(&self) -> &str { + self.0.type_name() + } + + /// Returns the name of the external system. + pub fn vendor_name(&self) -> &str { + self.0.vendor_name() + } +} + +impl From for Opaque { + fn from(value: OpaqueMetadata) -> Self { + Self(value) + } +} + +/// Extension type metadata for [`Opaque`]. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct OpaqueMetadata { + /// Name of the unknown type in the external system. + type_name: String, + + /// Name of the external system. + vendor_name: String, +} + +impl OpaqueMetadata { + /// Returns a new `OpaqueMetadata`. + pub fn new(type_name: impl Into, vendor_name: impl Into) -> Self { + OpaqueMetadata { + type_name: type_name.into(), + vendor_name: vendor_name.into(), + } + } + + /// Returns the name of the unknown type in the external system. + pub fn type_name(&self) -> &str { + &self.type_name + } + + /// Returns the name of the external system. + pub fn vendor_name(&self) -> &str { + &self.vendor_name + } +} + +impl ExtensionType for Opaque { + const NAME: &'static str = "arrow.opaque"; + + type Metadata = OpaqueMetadata; + + fn metadata(&self) -> &Self::Metadata { + &self.0 + } + + fn serialize_metadata(&self) -> Option { + Some(serde_json::to_string(self.metadata()).expect("metadata serialization")) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or_else( + || { + Err(ArrowError::InvalidArgumentError( + "Opaque extension types requires metadata".to_owned(), + )) + }, + |value| { + serde_json::from_str(value).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Opaque metadata deserialization failed: {e}" + )) + }) + }, + ) + } + + fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> { + // Any type + Ok(()) + } + + fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result { + Ok(Self::from(metadata)) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let opaque = Opaque::new("name", "vendor"); + let mut field = Field::new("", DataType::Null, false); + field.try_with_extension_type(opaque.clone())?; + assert_eq!(field.try_extension_type::()?, opaque); + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Opaque(opaque) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::Null, false).with_metadata( + [( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "type_name": "type", "vendor_name": "vendor" }"#.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "Opaque extension types requires metadata")] + fn missing_metadata() { + let field = Field::new("", DataType::Null, false).with_metadata( + [(EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "Opaque metadata deserialization failed: missing field `vendor_name`" + )] + fn invalid_metadata() { + let field = Field::new("", DataType::Null, false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned()), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "type_name": "no-vendor" }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs new file mode 100644 index 000000000000..8b2e71b7b5aa --- /dev/null +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UUID +//! +//! + +use crate::{extension::ExtensionType, ArrowError, DataType}; + +/// The extension type for `UUID`. +/// +/// Extension name: `arrow.uuid`. +/// +/// The storage type of the extension is `FixedSizeBinary` with a length of +/// 16 bytes. +/// +/// Note: +/// A specific UUID version is not required or guaranteed. This extension +/// represents UUIDs as `FixedSizeBinary(16)` with big-endian notation and +/// does not interpret the bytes in any way. +/// +/// +#[derive(Debug, Default, Clone, Copy, PartialEq)] +pub struct Uuid; + +impl ExtensionType for Uuid { + const NAME: &'static str = "arrow.uuid"; + + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or_else( + || Ok(()), + |_| { + Err(ArrowError::InvalidArgumentError( + "Uuid extension type expects no metadata".to_owned(), + )) + }, + ) + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + match data_type { + DataType::FixedSizeBinary(16) => Ok(()), + data_type => Err(ArrowError::InvalidArgumentError(format!( + "Uuid data type mismatch, expected FixedSizeBinary(16), found {data_type}" + ))), + } + } + + fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { + Self.supports_data_type(data_type).map(|_| Self) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let mut field = Field::new("", DataType::FixedSizeBinary(16), false); + field.try_with_extension_type(Uuid)?; + field.try_extension_type::()?; + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::Uuid(Uuid) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new("", DataType::FixedSizeBinary(16), false); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "expected FixedSizeBinary(16), found FixedSizeBinary(8)")] + fn invalid_type() { + Field::new("", DataType::FixedSizeBinary(8), false).with_extension_type(Uuid); + } + + #[test] + #[should_panic(expected = "Uuid extension type expects no metadata")] + fn with_metadata() { + let field = Field::new("", DataType::FixedSizeBinary(16), false).with_metadata( + [ + (EXTENSION_TYPE_NAME_KEY.to_owned(), Uuid::NAME.to_owned()), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned()), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } +} diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs new file mode 100644 index 000000000000..804591776b2f --- /dev/null +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! VariableShapeTensor +//! +//! + +use serde::{Deserialize, Serialize}; + +use crate::{extension::ExtensionType, ArrowError, DataType, Field}; + +/// The extension type for `VariableShapeTensor`. +/// +/// Extension name: `arrow.variable_shape_tensor`. +/// +/// The storage type of the extension is: StructArray where struct is composed +/// of data and shape fields describing a single tensor per row: +/// - `data` is a List holding tensor elements (each list element is a single +/// tensor). The List’s value type is the value type of the tensor, such as +/// an integer or floating-point type. +/// - `shape` is a `FixedSizeList[ndim]` of the tensor shape where the +/// size of the list `ndim` is equal to the number of dimensions of the +/// tensor. +/// +/// Extension type parameters: +/// `value_type`: the Arrow data type of individual tensor elements. +/// +/// Optional parameters describing the logical layout: +/// - `dim_names`: explicit names to tensor dimensions as an array. The length +/// of it should be equal to the shape length and equal to the number of +/// dimensions. +/// `dim_names` can be used if the dimensions have well-known names and they +/// map to the physical layout (row-major). +/// - `permutation`: indices of the desired ordering of the original +/// dimensions, defined as an array. +/// The indices contain a permutation of the values `[0, 1, .., N-1]` where +/// `N` is the number of dimensions. The permutation indicates which +/// dimension of the logical layout corresponds to which dimension of the +/// physical tensor (the i-th dimension of the logical view corresponds to +/// the dimension with number `permutations[i]` of the physical tensor). +/// Permutation can be useful in case the logical order of the tensor is a +/// permutation of the physical order (row-major). +/// When logical and physical layout are equal, the permutation will always +/// be (`[0, 1, .., N-1]`) and can therefore be left out. +/// - `uniform_shape`: sizes of individual tensor’s dimensions which are +/// guaranteed to stay constant in uniform dimensions and can vary in non- +/// uniform dimensions. This holds over all tensors in the array. Sizes in +/// uniform dimensions are represented with int32 values, while sizes of the +/// non-uniform dimensions are not known in advance and are represented with +/// null. If `uniform_shape` is not provided it is assumed that all +/// dimensions are non-uniform. An array containing a tensor with shape (2, +/// 3, 4) and whose first and last dimensions are uniform would have +/// `uniform_shape` (2, null, 4). This allows for interpreting the tensor +/// correctly without accounting for uniform dimensions while still +/// permitting optional optimizations that take advantage of the uniformity. +/// +/// +#[derive(Debug, Clone, PartialEq)] +pub struct VariableShapeTensor { + /// The data type of individual tensor elements. + value_type: DataType, + + /// The number of dimensions of the tensor. + dimensions: usize, + + /// The metadata of this extension type. + metadata: VariableShapeTensorMetadata, +} + +impl VariableShapeTensor { + /// Returns a new variable shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names, permutations or + /// uniform shapes are invalid. + pub fn try_new( + value_type: DataType, + dimensions: usize, + dimension_names: Option>, + permutations: Option>, + uniform_shapes: Option>>, + ) -> Result { + // TODO: are all data types are suitable as value type? + VariableShapeTensorMetadata::try_new( + dimensions, + dimension_names, + permutations, + uniform_shapes, + ) + .map(|metadata| Self { + value_type, + dimensions, + metadata, + }) + } + + /// Returns the value type of the individual tensor elements. + pub fn value_type(&self) -> &DataType { + &self.value_type + } + + /// Returns the number of dimensions in this variable shape tensor. + pub fn dimensions(&self) -> usize { + self.dimensions + } + + /// Returns the names of the dimensions in this variable shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.metadata.dimension_names() + } + + /// Returns the indices of the desired ordering of the original + /// dimensions, if set. + pub fn permutations(&self) -> Option<&[usize]> { + self.metadata.permutations() + } + + /// Returns sizes of individual tensor’s dimensions which are guaranteed + /// to stay constant in uniform dimensions and can vary in non-uniform + /// dimensions. + pub fn uniform_shapes(&self) -> Option<&[Option]> { + self.metadata.uniform_shapes() + } +} + +/// Extension type metadata for [`VariableShapeTensor`]. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct VariableShapeTensorMetadata { + /// Explicit names to tensor dimensions. + dim_names: Option>, + + /// Indices of the desired ordering of the original dimensions. + permutations: Option>, + + /// Sizes of individual tensor’s dimensions which are guaranteed to stay + /// constant in uniform dimensions and can vary in non-uniform dimensions. + uniform_shape: Option>>, +} + +impl VariableShapeTensorMetadata { + /// Returns metadata for a variable shape tensor extension type. + /// + /// # Error + /// + /// Return an error if the provided dimension names, permutations or + /// uniform shapes are invalid. + pub fn try_new( + dimensions: usize, + dimension_names: Option>, + permutations: Option>, + uniform_shapes: Option>>, + ) -> Result { + let dim_names = dimension_names.map(|dimension_names| { + if dimension_names.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len() + ))) + } else { + Ok(dimension_names) + } + }).transpose()?; + + let permutations = permutations + .map(|permutations| { + if permutations.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}", + permutations.len() + ))) + } else { + let mut sorted_permutations = permutations.clone(); + sorted_permutations.sort_unstable(); + if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" + ))) + } else { + Ok(permutations) + } + } + }) + .transpose()?; + + let uniform_shape = uniform_shapes + .map(|uniform_shapes| { + if uniform_shapes.len() != dimensions { + Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}", + uniform_shapes.len() + ))) + } else { + Ok(uniform_shapes) + } + }) + .transpose()?; + + Ok(Self { + dim_names, + permutations, + uniform_shape, + }) + } + + /// Returns the names of the dimensions in this variable shape tensor, if + /// set. + pub fn dimension_names(&self) -> Option<&[String]> { + self.dim_names.as_ref().map(AsRef::as_ref) + } + + /// Returns the indices of the desired ordering of the original dimensions, + /// if set. + pub fn permutations(&self) -> Option<&[usize]> { + self.permutations.as_ref().map(AsRef::as_ref) + } + + /// Returns sizes of individual tensor’s dimensions which are guaranteed + /// to stay constant in uniform dimensions and can vary in non-uniform + /// dimensions. + pub fn uniform_shapes(&self) -> Option<&[Option]> { + self.uniform_shape.as_ref().map(AsRef::as_ref) + } +} + +impl ExtensionType for VariableShapeTensor { + const NAME: &'static str = "arrow.variable_shape_tensor"; + + type Metadata = VariableShapeTensorMetadata; + + fn metadata(&self) -> &Self::Metadata { + &self.metadata + } + + fn serialize_metadata(&self) -> Option { + Some(serde_json::to_string(self.metadata()).expect("metadata serialization")) + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or_else( + || { + Err(ArrowError::InvalidArgumentError( + "VariableShapeTensor extension types requires metadata".to_owned(), + )) + }, + |value| { + serde_json::from_str(value).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor metadata deserialization failed: {e}" + )) + }) + }, + ) + } + + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { + let expected = DataType::Struct( + [ + Field::new_list( + "data", + Field::new_list_field(self.value_type.clone(), false), + false, + ), + Field::new( + "shape", + DataType::new_fixed_size_list( + DataType::Int32, + i32::try_from(self.dimensions()).expect("overflow"), + false, + ), + false, + ), + ] + .into_iter() + .collect(), + ); + data_type + .equals_datatype(&expected) + .then_some(()) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}" + )) + }) + } + + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { + match data_type { + DataType::Struct(fields) + if fields.len() == 2 + && matches!(fields.find("data"), Some((0, _))) + && matches!(fields.find("shape"), Some((1, _))) => + { + let shape_field = &fields[1]; + match shape_field.data_type() { + DataType::FixedSizeList(_, list_size) => { + let dimensions = usize::try_from(*list_size).expect("conversion failed"); + // Make sure the metadata is valid. + let metadata = VariableShapeTensorMetadata::try_new(dimensions, metadata.dim_names, metadata.permutations, metadata.uniform_shape)?; + let data_field = &fields[0]; + match data_field.data_type() { + DataType::List(field) => { + Ok(Self { + value_type: field.data_type().clone(), + dimensions, + metadata + }) + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}" + ))), + } + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}" + ))), + } + } + data_type => Err(ArrowError::InvalidArgumentError(format!( + "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "canonical_extension_types")] + use crate::extension::CanonicalExtensionType; + use crate::{ + extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Field, + }; + + use super::*; + + #[test] + fn valid() -> Result<(), ArrowError> { + let variable_shape_tensor = VariableShapeTensor::try_new( + DataType::Float32, + 3, + Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]), + Some(vec![2, 0, 1]), + Some(vec![Some(400), None, Some(3)]), + )?; + let mut field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ); + field.try_with_extension_type(variable_shape_tensor.clone())?; + assert_eq!( + field.try_extension_type::()?, + variable_shape_tensor + ); + #[cfg(feature = "canonical_extension_types")] + assert_eq!( + field.try_canonical_extension_type()?, + CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "Field extension type name missing")] + fn missing_name() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ) + .with_metadata( + [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")] + fn invalid_type() { + let variable_shape_tensor = + VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap(); + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ); + field.with_extension_type(variable_shape_tensor); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor extension types requires metadata")] + fn missing_metadata() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ) + .with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + VariableShapeTensor::NAME.to_owned(), + )] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")] + fn invalid_metadata() { + let field = Field::new_struct( + "", + vec![ + Field::new_list( + "data", + Field::new_list_field(DataType::Float32, false), + false, + ), + Field::new_fixed_size_list( + "shape", + Field::new("", DataType::Int32, false), + 3, + false, + ), + ], + false, + ) + .with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + VariableShapeTensor::NAME.to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + r#"{ "dim_names": [1, null, 3, 4] }"#.to_owned(), + ), + ] + .into_iter() + .collect(), + ); + field.extension_type::(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2" + )] + fn invalid_metadata_dimension_names() { + VariableShapeTensor::try_new( + DataType::Float32, + 3, + Some(vec!["a".to_owned(), "b".to_owned()]), + None, + None, + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2" + )] + fn invalid_metadata_permutations_len() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![1, 0]), None).unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3" + )] + fn invalid_metadata_permutations_values() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![4, 3, 2]), None) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2" + )] + fn invalid_metadata_uniform_shapes() { + VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(vec![None, Some(1)])) + .unwrap(); + } +} diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs new file mode 100644 index 000000000000..c5119873af0c --- /dev/null +++ b/arrow-schema/src/extension/mod.rs @@ -0,0 +1,260 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Extension types. +//! +//!
This module is experimental. There might be breaking changes between minor releases.
+ +#[cfg(feature = "canonical_extension_types")] +mod canonical; +#[cfg(feature = "canonical_extension_types")] +pub use canonical::*; + +use crate::{ArrowError, DataType}; + +/// The metadata key for the string name identifying an [`ExtensionType`]. +pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; + +/// The metadata key for a serialized representation of the [`ExtensionType`] +/// necessary to reconstruct the custom type. +pub const EXTENSION_TYPE_METADATA_KEY: &str = "ARROW:extension:metadata"; + +/// Extension types. +/// +/// User-defined “extension” types can be defined setting certain key value +/// pairs in the [`Field`] metadata structure. These extension keys are: +/// - [`EXTENSION_TYPE_NAME_KEY`] +/// - [`EXTENSION_TYPE_METADATA_KEY`] +/// +/// Canonical extension types support in this crate requires the +/// `canonical_extension_types` feature. +/// +/// Extension types may or may not use the [`EXTENSION_TYPE_METADATA_KEY`] +/// field. +/// +/// # Example +/// +/// The example below demonstrates how to implement this trait for a `Uuid` +/// type. Note this is not the canonical extension type for `Uuid`, which does +/// not include information about the `Uuid` version. +/// +/// ``` +/// # use arrow_schema::ArrowError; +/// # fn main() -> Result<(), ArrowError> { +/// use arrow_schema::{DataType, extension::ExtensionType, Field}; +/// use std::{fmt, str::FromStr}; +/// +/// /// The different Uuid versions. +/// #[derive(Clone, Copy, Debug, PartialEq)] +/// enum UuidVersion { +/// V1, +/// V2, +/// V3, +/// V4, +/// V5, +/// V6, +/// V7, +/// V8, +/// } +/// +/// // We'll use `Display` to serialize. +/// impl fmt::Display for UuidVersion { +/// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +/// write!( +/// f, +/// "{}", +/// match self { +/// Self::V1 => "V1", +/// Self::V2 => "V2", +/// Self::V3 => "V3", +/// Self::V4 => "V4", +/// Self::V5 => "V5", +/// Self::V6 => "V6", +/// Self::V7 => "V7", +/// Self::V8 => "V8", +/// } +/// ) +/// } +/// } +/// +/// // And `FromStr` to deserialize. +/// impl FromStr for UuidVersion { +/// type Err = ArrowError; +/// +/// fn from_str(s: &str) -> Result { +/// match s { +/// "V1" => Ok(Self::V1), +/// "V2" => Ok(Self::V2), +/// "V3" => Ok(Self::V3), +/// "V4" => Ok(Self::V4), +/// "V5" => Ok(Self::V5), +/// "V6" => Ok(Self::V6), +/// "V7" => Ok(Self::V7), +/// "V8" => Ok(Self::V8), +/// _ => Err(ArrowError::ParseError("Invalid UuidVersion".to_owned())), +/// } +/// } +/// } +/// +/// /// This is the extension type, not the container for Uuid values. It +/// /// stores the Uuid version (this is the metadata of this extension type). +/// #[derive(Clone, Copy, Debug, PartialEq)] +/// struct Uuid(UuidVersion); +/// +/// impl ExtensionType for Uuid { +/// // We use a namespace as suggested by the specification. +/// const NAME: &'static str = "myorg.example.uuid"; +/// +/// // The metadata type is the Uuid version. +/// type Metadata = UuidVersion; +/// +/// // We just return a reference to the Uuid version. +/// fn metadata(&self) -> &Self::Metadata { +/// &self.0 +/// } +/// +/// // We use the `Display` implementation to serialize the Uuid +/// // version. +/// fn serialize_metadata(&self) -> Option { +/// Some(self.0.to_string()) +/// } +/// +/// // We use the `FromStr` implementation to deserialize the Uuid +/// // version. +/// fn deserialize_metadata(metadata: Option<&str>) -> Result { +/// metadata.map_or_else( +/// || { +/// Err(ArrowError::InvalidArgumentError( +/// "Uuid extension type metadata missing".to_owned(), +/// )) +/// }, +/// str::parse, +/// ) +/// } +/// +/// // The only supported data type is `FixedSizeBinary(16)`. +/// fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { +/// match data_type { +/// DataType::FixedSizeBinary(16) => Ok(()), +/// data_type => Err(ArrowError::InvalidArgumentError(format!( +/// "Uuid data type mismatch, expected FixedSizeBinary(16), found {data_type}" +/// ))), +/// } +/// } +/// +/// // We should always check if the data type is supported before +/// // constructing the extension type. +/// fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { +/// let uuid = Self(metadata); +/// uuid.supports_data_type(data_type)?; +/// Ok(uuid) +/// } +/// } +/// +/// // We can now construct the extension type. +/// let uuid_v1 = Uuid(UuidVersion::V1); +/// +/// // And add it to a field. +/// let mut field = +/// Field::new("", DataType::FixedSizeBinary(16), false).with_extension_type(uuid_v1); +/// +/// // And extract it from this field. +/// assert_eq!(field.try_extension_type::()?, uuid_v1); +/// +/// // When we try to add this to a field with an unsupported data type we +/// // get an error. +/// let result = Field::new("", DataType::Null, false).try_with_extension_type(uuid_v1); +/// assert!(result.is_err()); +/// # Ok(()) } +/// ``` +/// +/// +/// +/// [`Field`]: crate::Field +pub trait ExtensionType: Sized { + /// The name identifying this extension type. + /// + /// This is the string value that is used for the + /// [`EXTENSION_TYPE_NAME_KEY`] in the [`Field::metadata`] of a [`Field`] + /// to identify this extension type. + /// + /// We recommend that you use a “namespace”-style prefix for extension + /// type names to minimize the possibility of conflicts with multiple Arrow + /// readers and writers in the same application. For example, use + /// `myorg.name_of_type` instead of simply `name_of_type`. + /// + /// Extension names beginning with `arrow.` are reserved for canonical + /// extension types, they should not be used for third-party extension + /// types. + /// + /// Extension names are case-sensitive. + /// + /// [`Field`]: crate::Field + /// [`Field::metadata`]: crate::Field::metadata + const NAME: &'static str; + + /// The metadata type of this extension type. + /// + /// Implementations can use strongly or loosly typed data structures here + /// depending on the complexity of the metadata. + /// + /// Implementations can also use `Self` here if the extension type can be + /// constructed directly from its metadata. + /// + /// If an extension type defines no metadata it should use `()` to indicate + /// this. + type Metadata; + + /// Returns a reference to the metadata of this extension type, or `&()` if + /// if this extension type defines no metadata (`Self::Metadata=()`). + fn metadata(&self) -> &Self::Metadata; + + /// Returns the serialized representation of the metadata of this extension + /// type, or `None` if this extension type defines no metadata + /// (`Self::Metadata=()`). + /// + /// This is string value that is used for the + /// [`EXTENSION_TYPE_METADATA_KEY`] in the [`Field::metadata`] of a + /// [`Field`]. + /// + /// [`Field`]: crate::Field + /// [`Field::metadata`]: crate::Field::metadata + fn serialize_metadata(&self) -> Option; + + /// Deserialize the metadata of this extension type from the serialized + /// representation of the metadata. An extension type that defines no + /// metadata should expect `None` for the serialized metadata and return + /// `Ok(())`. + /// + /// This function should return an error when + /// - expected metadata is missing (for extensions types with non-optional + /// metadata) + /// - unexpected metadata is set (for extension types without metadata) + /// - deserialization of metadata fails + fn deserialize_metadata(metadata: Option<&str>) -> Result; + + /// Returns `OK())` iff the given data type is supported by this extension + /// type. + fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError>; + + /// Construct this extension type for a field with the given data type and + /// metadata. + /// + /// This should return an error if the given data type is not supported by + /// this extension type. + fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; +} diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 13bb7abf51b4..dbd671a62a3a 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -22,8 +22,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::datatype::DataType; +#[cfg(feature = "canonical_extension_types")] +use crate::extension::CanonicalExtensionType; use crate::schema::SchemaBuilder; -use crate::{Fields, UnionFields, UnionMode}; +use crate::{ + extension::{ExtensionType, EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, + Fields, UnionFields, UnionMode, +}; /// A reference counted [`Field`] pub type FieldRef = Arc; @@ -350,6 +355,167 @@ impl Field { self } + /// Returns the extension type name of this [`Field`], if set. + /// + /// This returns the value of [`EXTENSION_TYPE_NAME_KEY`], if set in + /// [`Field::metadata`]. If the key is missing, there is no extension type + /// name and this returns `None`. + /// + /// # Example + /// + /// ``` + /// # use arrow_schema::{DataType, extension::EXTENSION_TYPE_NAME_KEY, Field}; + /// + /// let field = Field::new("", DataType::Null, false); + /// assert_eq!(field.extension_type_name(), None); + /// + /// let field = Field::new("", DataType::Null, false).with_metadata( + /// [(EXTENSION_TYPE_NAME_KEY.to_owned(), "example".to_owned())] + /// .into_iter() + /// .collect(), + /// ); + /// assert_eq!(field.extension_type_name(), Some("example")); + /// ``` + pub fn extension_type_name(&self) -> Option<&str> { + self.metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_ref) + } + + /// Returns the extension type metadata of this [`Field`], if set. + /// + /// This returns the value of [`EXTENSION_TYPE_METADATA_KEY`], if set in + /// [`Field::metadata`]. If the key is missing, there is no extension type + /// metadata and this returns `None`. + /// + /// # Example + /// + /// ``` + /// # use arrow_schema::{DataType, extension::EXTENSION_TYPE_METADATA_KEY, Field}; + /// + /// let field = Field::new("", DataType::Null, false); + /// assert_eq!(field.extension_type_metadata(), None); + /// + /// let field = Field::new("", DataType::Null, false).with_metadata( + /// [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "example".to_owned())] + /// .into_iter() + /// .collect(), + /// ); + /// assert_eq!(field.extension_type_metadata(), Some("example")); + /// ``` + pub fn extension_type_metadata(&self) -> Option<&str> { + self.metadata() + .get(EXTENSION_TYPE_METADATA_KEY) + .map(String::as_ref) + } + + /// Returns an instance of the given [`ExtensionType`] of this [`Field`], + /// if set in the [`Field::metadata`]. + /// + /// # Error + /// + /// Returns an error if + /// - this field does not have the name of this extension type + /// ([`ExtensionType::NAME`]) in the [`Field::metadata`] (mismatch or + /// missing) + /// - the deserialization of the metadata + /// ([`ExtensionType::deserialize_metadata`]) fails + /// - the construction of the extension type ([`ExtensionType::try_new`]) + /// fail (for example when the [`Field::data_type`] is not supported by + /// the extension type ([`ExtensionType::supports_data_type`])) + pub fn try_extension_type(&self) -> Result { + // Check the extension name in the metadata + match self.extension_type_name() { + // It should match the name of the given extension type + Some(name) if name == E::NAME => { + // Deserialize the metadata and try to construct the extension + // type + E::deserialize_metadata(self.extension_type_metadata()) + .and_then(|metadata| E::try_new(self.data_type(), metadata)) + } + // Name mismatch + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Field extension type name mismatch, expected {}, found {name}", + E::NAME + ))), + // Name missing + None => Err(ArrowError::InvalidArgumentError( + "Field extension type name missing".to_owned(), + )), + } + } + + /// Returns an instance of the given [`ExtensionType`] of this [`Field`], + /// panics if this [`Field`] does not have this extension type. + /// + /// # Panic + /// + /// This calls [`Field::try_extension_type`] and panics when it returns an + /// error. + pub fn extension_type(&self) -> E { + self.try_extension_type::() + .unwrap_or_else(|e| panic!("{e}")) + } + + /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] + /// and [`ExtensionType::metadata`] of the given [`ExtensionType`], if the + /// given extension type supports the [`Field::data_type`] of this field + /// ([`ExtensionType::supports_data_type`]). + /// + /// If the given extension type defines no metadata, a previously set + /// value of [`EXTENSION_TYPE_METADATA_KEY`] is cleared. + /// + /// # Error + /// + /// This functions returns an error if the data type of this field does not + /// match any of the supported storage types of the given extension type. + pub fn try_with_extension_type( + &mut self, + extension_type: E, + ) -> Result<(), ArrowError> { + // Make sure the data type of this field is supported + extension_type.supports_data_type(&self.data_type)?; + + self.metadata + .insert(EXTENSION_TYPE_NAME_KEY.to_owned(), E::NAME.to_owned()); + match extension_type.serialize_metadata() { + Some(metadata) => self + .metadata + .insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata), + // If this extension type has no metadata, we make sure to + // clear previously set metadata. + None => self.metadata.remove(EXTENSION_TYPE_METADATA_KEY), + }; + + Ok(()) + } + + /// Updates the metadata of this [`Field`] with the [`ExtensionType::NAME`] + /// and [`ExtensionType::metadata`] of the given [`ExtensionType`]. + /// + /// # Panics + /// + /// This calls [`Field::try_with_extension_type`] and panics when it + /// returns an error. + pub fn with_extension_type(mut self, extension_type: E) -> Self { + self.try_with_extension_type(extension_type) + .unwrap_or_else(|e| panic!("{e}")); + self + } + + /// Returns the [`CanonicalExtensionType`] of this [`Field`], if set. + /// + /// # Error + /// + /// Returns an error if + /// - this field does have a canonical extension type (mismatch or missing) + /// - the canonical extension is not supported + /// - the construction of the extension type fails + #[cfg(feature = "canonical_extension_types")] + pub fn try_canonical_extension_type(&self) -> Result { + CanonicalExtensionType::try_from(self) + } + /// Indicates whether this [`Field`] supports null values. #[inline] pub const fn is_nullable(&self) -> bool { diff --git a/arrow-schema/src/lib.rs b/arrow-schema/src/lib.rs index d06382fbcdf7..a83e23e27592 100644 --- a/arrow-schema/src/lib.rs +++ b/arrow-schema/src/lib.rs @@ -25,6 +25,7 @@ use std::fmt::Display; mod datatype_parse; mod error; pub use error::*; +pub mod extension; mod field; pub use field::*; mod fields; diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs index 2a532600b6cc..c363b99920a7 100644 --- a/arrow-select/src/dictionary.rs +++ b/arrow-select/src/dictionary.rs @@ -315,7 +315,7 @@ mod tests { assert_eq!(merged.values.as_ref(), &expected); assert_eq!(merged.key_mappings.len(), 2); assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]); - assert_eq!(&merged.key_mappings[1], &[]); + assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]); } #[test] diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 1b01dcd25e54..88231b7f6160 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -80,6 +80,7 @@ force_validate = ["arrow-array/force_validate", "arrow-data/force_validate"] # Enable ffi support ffi = ["arrow-schema/ffi", "arrow-data/ffi", "arrow-array/ffi"] chrono-tz = ["arrow-array/chrono-tz"] +canonical_extension_types = ["arrow-schema/canonical_extension_types"] [dev-dependencies] chrono = { workspace = true } diff --git a/arrow/README.md b/arrow/README.md index 79aefaae9053..64d9eb980e60 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -61,6 +61,7 @@ The `arrow` crate provides the following features which may be enabled in your ` - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) - `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) - `pyarrow` - bindings for pyo3 to call arrow-rs from python +- `canonical_extension_types` - definitions for [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html#format-canonical-extensions) ## Arrow Feature Status diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 54992d864d85..00d4c5b750f8 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -103,6 +103,8 @@ default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64", "simdut lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] +# Enable support for arrow canonical extension types +arrow_canonical_extension_types = ["arrow-schema?/canonical_extension_types"] # Enable CLI tools cli = ["json", "base64", "clap", "arrow-csv", "serde"] # Enable JSON APIs diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 8be2439002be..8b3e92251bd1 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -23,6 +23,8 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_ipc::writer; +#[cfg(feature = "arrow_canonical_extension_types")] +use arrow_schema::extension::{Json, Uuid}; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ @@ -380,12 +382,26 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result ret.try_with_extension_type(Uuid)?, + LogicalType::Json => ret.try_with_extension_type(Json::default())?, + _ => {} + } + } + if !meta.is_empty() { ret.set_metadata(meta); } @@ -590,6 +606,16 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { .with_repetition(repetition) .with_id(id) .with_length(*length) + .with_logical_type( + #[cfg(feature = "arrow_canonical_extension_types")] + // If set, map arrow uuid extension type to parquet uuid logical type. + field + .try_extension_type::() + .ok() + .map(|_| LogicalType::Uuid), + #[cfg(not(feature = "arrow_canonical_extension_types"))] + None, + ) .build() } DataType::BinaryView => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) @@ -623,13 +649,35 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { } DataType::Utf8 | DataType::LargeUtf8 => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::String)) + .with_logical_type({ + #[cfg(feature = "arrow_canonical_extension_types")] + { + // Use the Json logical type if the canonical Json + // extension type is set on this field. + field + .try_extension_type::() + .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) + } + #[cfg(not(feature = "arrow_canonical_extension_types"))] + Some(LogicalType::String) + }) .with_repetition(repetition) .with_id(id) .build() } DataType::Utf8View => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::String)) + .with_logical_type({ + #[cfg(feature = "arrow_canonical_extension_types")] + { + // Use the Json logical type if the canonical Json + // extension type is set on this field. + field + .try_extension_type::() + .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) + } + #[cfg(not(feature = "arrow_canonical_extension_types"))] + Some(LogicalType::String) + }) .with_repetition(repetition) .with_id(id) .build(), @@ -2163,4 +2211,52 @@ mod tests { fn test_get_arrow_schema_from_metadata() { assert!(get_arrow_schema_from_metadata("").is_err()); } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn arrow_uuid_to_parquet_uuid() -> Result<()> { + let arrow_schema = Schema::new(vec![Field::new( + "uuid", + DataType::FixedSizeBinary(16), + false, + ) + .with_extension_type(Uuid)]); + + let parquet_schema = ArrowSchemaConverter::new().convert(&arrow_schema)?; + + assert_eq!( + parquet_schema.column(0).logical_type(), + Some(LogicalType::Uuid) + ); + + // TODO: roundtrip + // let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + // assert_eq!(arrow_schema.field(0).try_extension_type::()?, Uuid); + + Ok(()) + } + + #[test] + #[cfg(feature = "arrow_canonical_extension_types")] + fn arrow_json_to_parquet_json() -> Result<()> { + let arrow_schema = Schema::new(vec![ + Field::new("json", DataType::Utf8, false).with_extension_type(Json::default()) + ]); + + let parquet_schema = ArrowSchemaConverter::new().convert(&arrow_schema)?; + + assert_eq!( + parquet_schema.column(0).logical_type(), + Some(LogicalType::Json) + ); + + // TODO: roundtrip + // let arrow_schema = parquet_to_arrow_schema(&parquet_schema, None)?; + // assert_eq!( + // arrow_schema.field(0).try_extension_type::()?, + // Json::default() + // ); + + Ok(()) + } }