Skip to content

Commit

Permalink
feat: allow casting in alter_columns() (#1909)
Browse files Browse the repository at this point in the history
This allows casting a column in place. For example, a user could change
a `fixed_size_list<f32, _>` column to a `fixed_size_list<f16, _>`. In a
future PR, we will make sure the indices can be preserved as part of the
transaction.
  • Loading branch information
wjones127 authored Feb 22, 2024
1 parent dc01633 commit 45c158d
Show file tree
Hide file tree
Showing 12 changed files with 532 additions and 67 deletions.
20 changes: 19 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,18 @@ def join(
raise NotImplementedError("Versioning not yet supported in Rust")

def alter_columns(self, *alterations: Iterable[Dict[str, Any]]):
"""Alter column names and nullability.
"""Alter column name, data type, and nullability.
Columns that are renamed can keep any indices that are on them. However, if
the column is cast to a different type, its indices will be dropped.
Column types can be upcast (such as int32 to int64) or downcast
(such as int64 to int32). However, downcasting will fail if there are
any values that cannot be represented in the new type. In general,
columns can be cast to the same general type: integers to integers,
floats to floats, and strings to strings. However, strings, binary, and
list columns can be cast between their size variants. For example,
string to large string, binary to large binary, and list to large list.
Parameters
----------
Expand All @@ -638,6 +649,9 @@ def alter_columns(self, *alterations: Iterable[Dict[str, Any]]):
nullability is not changed. Only non-nullable columns can be changed
to nullable. Currently, you cannot change a nullable column to
non-nullable.
- "data_type": pyarrow.DataType, optional
The new data type to cast the column to. If not specified, the column
data type is not changed.
Examples
--------
Expand All @@ -654,6 +668,10 @@ def alter_columns(self, *alterations: Iterable[Dict[str, Any]]):
0 1 a
1 2 b
2 3 c
>>> dataset.alter_columns({"path": "x", "data_type": pa.int32()})
>>> dataset.schema
x: int32
b: string
"""
self._ds.alter_columns(list(alterations))

Expand Down
37 changes: 35 additions & 2 deletions python/python/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_alter_columns(tmp_path: Path):
pa.field("b", pa.string(), nullable=False),
])
tab = pa.table(
{"a": pa.array([1, 2, 3]), "b": pa.array(["a", "b", "c"])}, schema=schema
{"a": pa.array([1, 2, 1024]), "b": pa.array(["a", "b", "c"])}, schema=schema
)

dataset = lance.write_dataset(tab, tmp_path)
Expand All @@ -177,7 +177,40 @@ def test_alter_columns(tmp_path: Path):
assert dataset.schema == expected_schema

expected_tab = pa.table(
{"x": pa.array([1, 2, 3]), "y": pa.array(["a", "b", "c"])},
{"x": pa.array([1, 2, 1024]), "y": pa.array(["a", "b", "c"])},
schema=expected_schema,
)
assert dataset.to_table() == expected_tab

dataset.alter_columns(
{"path": "x", "data_type": pa.int32()},
{"path": "y", "data_type": pa.large_string()},
)
expected_schema = pa.schema([
pa.field("x", pa.int32()),
pa.field("y", pa.large_string(), nullable=False),
])
assert dataset.schema == expected_schema

expected_tab = pa.table(
{"x": pa.array([1, 2, 1024], type=pa.int32()), "y": pa.array(["a", "b", "c"])},
schema=expected_schema,
)
assert dataset.to_table() == expected_tab
with pytest.raises(Exception, match="Can't cast value 1024 to type Int8"):
dataset.alter_columns({"path": "x", "data_type": pa.int8()})

with pytest.raises(Exception, match='Cannot cast column "x" from Int32 to Utf8'):
dataset.alter_columns({"path": "x", "data_type": pa.string()})

with pytest.raises(Exception, match='Column "q" does not exist'):
dataset.alter_columns({"path": "q", "name": "z"})

with pytest.raises(ValueError, match="Unknown key: type"):
dataset.alter_columns({"path": "x", "type": "string"})

with pytest.raises(
ValueError,
match="At least one of name, nullable, or data_type must be specified",
):
dataset.alter_columns({"path": "x"})
24 changes: 24 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,13 +653,37 @@ impl Dataset {
obj.get_item("name")?.map(|n| n.extract()).transpose()?;
let nullable: Option<bool> =
obj.get_item("nullable")?.map(|n| n.extract()).transpose()?;
let data_type: Option<PyArrowType<DataType>> = obj
.get_item("data_type")?
.map(|n| n.extract())
.transpose()?;

for key in obj.keys().iter().map(|k| k.extract::<String>()) {
let k = key?;
if k != "path" && k != "name" && k != "nullable" && k != "data_type" {
return Err(PyValueError::new_err(format!(
"Unknown key: {}. Valid keys are name, nullable, and data_type.",
k
)));
}
}

if name.is_none() && nullable.is_none() && data_type.is_none() {
return Err(PyValueError::new_err(
"At least one of name, nullable, or data_type must be specified",
));
}

let mut alteration = ColumnAlteration::new(path);
if let Some(name) = name {
alteration = alteration.rename(name);
}
if let Some(nullable) = nullable {
alteration = alteration.set_nullable(nullable);
}
if let Some(data_type) = data_type {
alteration = alteration.cast_to(data_type.0);
}
Ok(alteration)
})
.collect::<PyResult<Vec<_>>>()?;
Expand Down
2 changes: 1 addition & 1 deletion python/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl FileFragment {
fn updater(&self, columns: Option<Vec<String>>) -> PyResult<Updater> {
let cols = columns.as_deref();
let inner = RT
.block_on(None, async { self.fragment.updater(cols).await })?
.block_on(None, async { self.fragment.updater(cols, None).await })?
.map_err(|err| PyIOError::new_err(err.to_string()))?;
Ok(Updater::new(inner))
}
Expand Down
1 change: 1 addition & 0 deletions rust/lance-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ crate-type = ["cdylib", "rlib"]
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-cast ={ workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
half = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions rust/lance-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub use schema::*;
pub mod bfloat16;
pub mod floats;
pub use floats::*;
pub mod cast;

type Result<T> = std::result::Result<T, ArrowError>;

Expand Down
48 changes: 39 additions & 9 deletions rust/lance-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,45 @@ impl Field {
.collect::<Vec<_>>();
if !children.is_empty() || filter(self) {
Some(Self {
name: self.name.clone(),
id: self.id,
parent_id: self.parent_id,
logical_type: self.logical_type.clone(),
metadata: self.metadata.clone(),
encoding: self.encoding.clone(),
nullable: self.nullable,
children,
dictionary: self.dictionary.clone(),
..self.clone()
})
} else {
None
}
}

/// Create a new field by selecting fields by their ids.
///
/// If a field has its id in the list of ids then it will be included
/// in the new field. If a field is selected, all of its parents and
/// all of its children will be included.
///
/// For example, for the schema:
///
/// ```text
/// 0: x struct {
/// 1: y int32
/// 2: l list {
/// 3: z int32
/// }
/// }
/// ```
///
/// If the ids are `[2]`, then this will include the parent `0` and the
/// child `3`.
pub(crate) fn project_by_ids(&self, ids: &[i32]) -> Option<Self> {
let children = self
.children
.iter()
.filter_map(|c| c.project_by_ids(ids))
.collect::<Vec<_>>();
if ids.contains(&self.id) {
Some(self.clone())
} else if !children.is_empty() {
Some(Self {
children,
..self.clone()
})
} else {
None
Expand Down Expand Up @@ -610,7 +640,7 @@ impl Field {
}

/// Recursively set field ID and parent ID for this field and all its children.
pub(super) fn set_id(&mut self, parent_id: i32, id_seed: &mut i32) {
pub fn set_id(&mut self, parent_id: i32, id_seed: &mut i32) {
self.parent_id = parent_id;
if self.id < 0 {
self.id = *id_seed;
Expand Down
16 changes: 14 additions & 2 deletions rust/lance-core/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl Schema {
let filtered_fields = self
.fields
.iter()
.filter_map(|f| f.project_by_filter(&|f| column_ids.contains(&f.id)))
.filter_map(|f| f.project_by_ids(column_ids))
.collect();
Self {
fields: filtered_fields,
Expand Down Expand Up @@ -605,7 +605,7 @@ mod tests {
ArrowField::new("c", DataType::Float64, false),
]);
let schema = Schema::try_from(&arrow_schema).unwrap();
let projected = schema.project_by_ids(&[1, 2, 4, 5]);
let projected = schema.project_by_ids(&[2, 4, 5]);

let expected_arrow_schema = ArrowSchema::new(vec![
ArrowField::new(
Expand All @@ -631,6 +631,18 @@ mod tests {
true,
)]);
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);

let projected = schema.project_by_ids(&[1]);
let expected_arrow_schema = ArrowSchema::new(vec![ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f2", DataType::Boolean, false),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
)]);
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
}

#[test]
Expand Down
Loading

0 comments on commit 45c158d

Please sign in to comment.