diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 4fd711320a..d32a01ca9e 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -71,6 +71,24 @@ def test_input_data(tmp_path: Path, schema, data): assert dataset.to_table() == input_data[0][1] +def test_roundtrip_types(tmp_path: Path): + table = pa.table({ + "dict": pa.array(["a", "b", "a"], pa.dictionary(pa.int8(), pa.string())), + # PyArrow doesn't support creating large_string dictionaries easily. + "large_dict": pa.DictionaryArray.from_arrays( + pa.array([0, 1, 1], pa.int8()), pa.array(["foo", "bar"], pa.large_string()) + ), + "list": pa.array([["a", "b"], ["c", "d"], ["e", "f"]], pa.list_(pa.string())), + "large_list": pa.array( + [["a", "b"], ["c", "d"], ["e", "f"]], pa.large_list(pa.string()) + ), + }) + + dataset = lance.write_dataset(table, tmp_path) + assert dataset.schema == table.schema + assert dataset.to_table() == table + + def test_dataset_overwrite(tmp_path: Path): table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) base_dir = tmp_path / "test" diff --git a/rust/lance-file/src/datatypes.rs b/rust/lance-file/src/datatypes.rs index 8522516ee4..c54b3cb64a 100644 --- a/rust/lance-file/src/datatypes.rs +++ b/rust/lance-file/src/datatypes.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use arrow_schema::DataType; use async_recursion::async_recursion; use lance_arrow::bfloat16::ARROW_EXT_NAME_KEY; +use lance_arrow::DataTypeExt; use lance_core::datatypes::{Dictionary, Encoding, Field, LogicalType, Schema}; use lance_core::{Error, Result}; use lance_io::traits::Reader; @@ -194,7 +195,7 @@ async fn load_field_dictionary<'a>(field: &mut Field, reader: &dyn Reader) -> Re if let Some(dict_info) = field.dictionary.as_mut() { use DataType::*; match value_type.as_ref() { - Utf8 | Binary => { + _ if value_type.is_binary_like() => { dict_info.values = Some( read_binary_array( reader,