Skip to content

Commit

Permalink
feat: add alter columns for names and nullability (#1903)
Browse files Browse the repository at this point in the history
Initial pass on the `alter_columns()` API. This allows renaming columns and
making them nullable. A future PR will allow casting the type of a column.
  • Loading branch information
wjones127 authored Feb 5, 2024
1 parent 912b8fc commit 33bf813
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,43 @@ def join(
"""
raise NotImplementedError("Versioning not yet supported in Rust")

def alter_columns(self, *alterations: Dict[str, Any]):
    """Alter column names and nullability.

    This is a metadata-only operation; the underlying data files are not
    rewritten.

    Parameters
    ----------
    *alterations : Dict[str, Any]
        One or more dictionaries, each describing a change to a single
        column, with the following keys:

        - "path": str
            The column path to alter. For a top-level column, this is the name.
            For a nested column, this is the dot-separated path, e.g. "a.b.c".
        - "name": str, optional
            The new name of the column. If not specified, the column name is
            not changed.
        - "nullable": bool, optional
            Whether the column should be nullable. If not specified, the column
            nullability is not changed. Only non-nullable columns can be changed
            to nullable. Currently, you cannot change a nullable column to
            non-nullable.

    Examples
    --------
    >>> import lance
    >>> import pyarrow as pa
    >>> schema = pa.schema([pa.field('a', pa.int64()),
    ...                     pa.field('b', pa.string(), nullable=False)])
    >>> table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}, schema=schema)
    >>> dataset = lance.write_dataset(table, "example")
    >>> dataset.alter_columns({"path": "a", "name": "x"},
    ...                       {"path": "b", "nullable": True})
    >>> dataset.to_table().to_pandas()
       x  b
    0  1  a
    1  2  b
    2  3  c
    """
    # Each positional argument is one alteration dict; forward them as a
    # single list to the native implementation.
    self._ds.alter_columns(list(alterations))

def merge(
self,
data_obj: ReaderLike,
Expand Down
29 changes: 29 additions & 0 deletions python/python/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,32 @@ def test_query_after_merge(tmp_path):
dataset.to_table(
nearest=dict(column="vec", k=10, q=np.random.rand(128).astype("float32"))
)


def test_alter_columns(tmp_path: Path):
    """Renaming columns and relaxing nullability updates the schema while
    leaving the stored data intact."""
    original_schema = pa.schema([
        pa.field("a", pa.int64(), nullable=False),
        pa.field("b", pa.string(), nullable=False),
    ])
    data = pa.table(
        {"a": pa.array([1, 2, 3]), "b": pa.array(["a", "b", "c"])},
        schema=original_schema,
    )

    dataset = lance.write_dataset(data, tmp_path)

    dataset.alter_columns(
        {"path": "a", "name": "x", "nullable": True},
        {"path": "b", "name": "y"},
    )

    # "x" was made nullable; "y" keeps its non-nullable flag.
    renamed_schema = pa.schema([
        pa.field("x", pa.int64()),
        pa.field("y", pa.string(), nullable=False),
    ])
    assert dataset.schema == renamed_schema

    assert dataset.to_table() == pa.table(
        {"x": pa.array([1, 2, 3]), "y": pa.array(["a", "b", "c"])},
        schema=renamed_schema,
    )
35 changes: 35 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use chrono::Duration;

use futures::{StreamExt, TryFutureExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::ColumnAlteration;
use lance::dataset::{
fragment::FileFragment as LanceFileFragment, progress::WriteFragmentProgress,
scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation,
Expand Down Expand Up @@ -629,6 +630,40 @@ impl Dataset {
Ok(PyArrowType(Box::new(LanceReader::from_stream(stream))))
}

/// Apply a list of column alterations (each a Python dict with keys
/// "path", and optionally "name" and "nullable") to the dataset.
fn alter_columns(&mut self, alterations: &PyList) -> PyResult<()> {
    // Translate each Python dict into a `ColumnAlteration`.
    let mut parsed = Vec::with_capacity(alterations.len());
    for item in alterations.iter() {
        let spec = item.downcast::<PyDict>()?;
        let path: String = spec
            .get_item("path")?
            .ok_or_else(|| PyValueError::new_err("path is required"))?
            .extract()?;
        let mut alteration = ColumnAlteration::new(path);
        if let Some(name) = spec.get_item("name")? {
            alteration = alteration.rename(name.extract::<String>()?);
        }
        if let Some(nullable) = spec.get_item("nullable")? {
            alteration = alteration.set_nullable(nullable.extract::<bool>()?);
        }
        parsed.push(alteration);
    }

    // The dataset is cloned so the async mutation can own it; on success we
    // swap the updated copy back into `self.ds`.
    let mut new_self = self.ds.as_ref().clone();
    new_self = RT
        .spawn(None, async move {
            new_self.alter_columns(&parsed).await.map(|_| new_self)
        })?
        .map_err(|err| PyIOError::new_err(err.to_string()))?;
    self.ds = Arc::new(new_self);
    Ok(())
}

fn merge(
&mut self,
reader: PyArrowType<ArrowArrayStreamReader>,
Expand Down
19 changes: 19 additions & 0 deletions rust/lance-core/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,32 @@ impl Schema {
/// to distinguish from nested fields.
// TODO: pub(crate)
pub fn validate(&self) -> Result<()> {
let mut seen_names = HashSet::new();

for field in self.fields.iter() {
if field.name.contains('.') {
return Err(Error::Schema{message:format!(
"Top level field {} cannot contain `.`. Maybe you meant to create a struct field?",
field.name.clone()
), location: location!(),});
}

let column_path = self
.field_ancestry_by_id(field.id)
.unwrap()
.iter()
.map(|f| f.name.as_str())
.collect::<Vec<_>>()
.join(".");
if !seen_names.insert(column_path.clone()) {
return Err(Error::Schema {
message: format!(
"Duplicate field name \"{}\" in schema:\n {:#?}",
column_path, self
),
location: location!(),
});
}
}

// Check for duplicate field ids
Expand Down
197 changes: 197 additions & 0 deletions rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,40 @@ pub enum NewColumnTransform {
SqlExpressions(Vec<(String, String)>),
}

/// Definition of a change to a column in a dataset.
///
/// Built with [`ColumnAlteration::new`] and configured via the
/// builder-style methods [`ColumnAlteration::rename`] and
/// [`ColumnAlteration::set_nullable`].
#[derive(Debug, Clone)]
pub struct ColumnAlteration {
    /// Path to the existing column to be altered.
    ///
    /// For a nested column this is a dot-separated path, e.g. `a.b.c`.
    pub path: String,
    /// The new name of the column. If None, the column name will not be changed.
    pub rename: Option<String>,
    /// Whether the column is nullable. If None, the nullability will not be changed.
    pub nullable: Option<bool>,
    // TODO: support changing data type.
}

impl ColumnAlteration {
    /// Create an alteration for the column at `path` that changes nothing yet.
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            rename: None,
            nullable: None,
        }
    }

    /// Set the new name for the column.
    pub fn rename(mut self, name: impl Into<String>) -> Self {
        self.rename = Some(name.into());
        self
    }

    /// Set whether the column should be nullable.
    pub fn set_nullable(mut self, nullable: bool) -> Self {
        self.nullable = Some(nullable);
        self
    }
}

// TODO: move all schema evolution methods to this impl and provide a dedicated
// docs section to describe the schema evolution methods.
impl Dataset {
/// Append new columns to the dataset.
pub async fn add_columns(
Expand Down Expand Up @@ -1657,6 +1691,68 @@ impl Dataset {
Ok(fragments)
}

/// Modify columns in the dataset, changing their name or nullability.
///
/// Changing a column's data type is not yet supported (see the TODO on
/// [`ColumnAlteration`]).
///
/// If a column has an index, its index will be preserved.
///
/// # Errors
///
/// Returns an invalid-input error if a referenced column does not exist,
/// or if an alteration would make a nullable column non-nullable.
pub async fn alter_columns(&mut self, alterations: &[ColumnAlteration]) -> Result<()> {
    // Validate we aren't making nullable columns non-nullable and that all
    // the referenced columns actually exist.
    let mut new_schema = self.schema().clone();

    for alteration in alterations {
        let field = self.schema().field(&alteration.path).ok_or_else(|| {
            Error::invalid_input(
                format!("Column {} does not exist in the dataset", alteration.path),
                location!(),
            )
        })?;
        if let Some(nullable) = alteration.nullable {
            // TODO: in the future, we could check the values of the column to see if
            // they are all non-null and thus the column could be made non-nullable.
            if field.nullable && !nullable {
                return Err(Error::invalid_input(
                    format!(
                        "Column {} is already nullable and thus cannot be made non-nullable",
                        alteration.path
                    ),
                    location!(),
                ));
            }
        }

        // Apply the rename/nullability change to the working copy of the schema.
        // `field.id` was just looked up above, so the unwrap cannot fail.
        let field_mut = new_schema.mut_field_by_id(field.id).unwrap();
        if let Some(rename) = &alteration.rename {
            field_mut.name = rename.clone();
        }
        if let Some(nullable) = alteration.nullable {
            field_mut.nullable = nullable;
        }
    }

    // Catches e.g. renames that produce duplicate column names.
    new_schema.validate()?;

    // If we aren't casting a column, we don't need to touch the fragments.
    let transaction = Transaction::new(
        self.manifest.version,
        Operation::Project { schema: new_schema },
        None,
    );

    let manifest = commit_transaction(
        self,
        &self.object_store,
        self.commit_handler.as_ref(),
        &transaction,
        &Default::default(),
        &Default::default(),
    )
    .await?;

    self.manifest = Arc::new(manifest);

    Ok(())
}

/// Remove columns from the dataset.
///
/// This is a metadata-only operation and does not remove the data from the
Expand Down Expand Up @@ -4208,4 +4304,105 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_rename_columns() -> Result<()> {
let metadata: HashMap<String, String> = [("k1".into(), "v1".into())].into();

let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
Field::new("a", DataType::Int32, false),
Field::new(
"b",
DataType::Struct(ArrowFields::from(vec![Field::new(
"c",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
));

let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StructArray::from(vec![(
Arc::new(ArrowField::new("c", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
)])),
],
)?;

let test_dir = tempdir()?;
let test_uri = test_dir.path().to_str().unwrap();

let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let mut dataset = Dataset::write(batches, test_uri, None).await?;

let original_fragments = dataset.fragments().to_vec();

// Rename a top-level column
dataset
.alter_columns(&[ColumnAlteration::new("a".into())
.rename("x".into())
.set_nullable(true)])
.await?;
dataset.validate().await?;
assert_eq!(dataset.manifest.version, 2);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);

let expected_schema = ArrowSchema::new_with_metadata(
vec![
Field::new("x", DataType::Int32, true),
Field::new(
"b",
DataType::Struct(ArrowFields::from(vec![Field::new(
"c",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

// Rename to duplicate name fails
let err = dataset
.alter_columns(&[ColumnAlteration::new("b".into()).rename("x".into())])
.await
.unwrap_err();
assert!(err.to_string().contains("Duplicate field name \"x\""));

// Rename a nested column.
dataset
.alter_columns(&[ColumnAlteration::new("b.c".into()).rename("d".into())])
.await?;
dataset.validate().await?;
assert_eq!(dataset.manifest.version, 3);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);

let expected_schema = ArrowSchema::new_with_metadata(
vec![
Field::new("x", DataType::Int32, true),
Field::new(
"b",
DataType::Struct(ArrowFields::from(vec![Field::new(
"d",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

Ok(())
}
}

0 comments on commit 33bf813

Please sign in to comment.