From 5d898a7cc06dd52a5f6369a8cebbf3ebe70b0e6e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 2 Feb 2024 19:23:15 -0800 Subject: [PATCH] feat(python): expose `drop_columns()` in Python (#1904) Renames the Rust method `drop()` to `drop_columns()` for clarity and also alignment with `add_columns()` and `alter_columns()`. Closes #1076 Related #1674 --- .pre-commit-config.yaml | 6 +- python/Makefile | 8 +- python/python/lance/dataset.py | 31 ++ python/python/tests/test_schema_evolution.py | 50 +++ python/src/dataset.rs | 13 + rust/lance/src/dataset.rs | 399 +++++-------------- rust/lance/src/dataset/transaction.rs | 17 + 7 files changed, 217 insertions(+), 307 deletions(-) create mode 100644 python/python/tests/test_schema_evolution.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8849b20c2f..c600e4aaca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.2.0 hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix] + args: [--preview, --fix, --exit-non-zero-on-fix] - id: ruff-format - args: [] + args: [--preview] diff --git a/python/Makefile b/python/Makefile index 68083dfbf5..a2c9df969f 100644 --- a/python/Makefile +++ b/python/Makefile @@ -17,16 +17,16 @@ format: format-python .PHONY: format format-python: - ruff format python - ruff --fix python + ruff format --preview python + ruff --preview --fix python .PHONY: format-python lint: lint-python lint-rust .PHONY: lint lint-python: - ruff format --check python - ruff python + ruff format --preview --check python + ruff --preview python .PHONY: lint-python lint-rust: diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index eccc116100..1b80dc696a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -661,6 +661,37 @@ def merge( self._ds.merge(reader, left_on, right_on) + def drop_columns(self, columns: List[str]): + """Drop one or more columns from the dataset + + Parameters + ---------- + columns : list of str + The names of the columns to drop. These can be nested column references + (e.g. "a.b.c") or top-level column names (e.g. "a"). + + This is a metadata-only operation and does not remove the data from the + underlying storage. In order to remove the data, you must subsequently + call ``compact_files`` to rewrite the data without the removed columns and + then call ``cleanup_files`` to remove the old files. + + Examples + -------- + >>> import lance + >>> import pyarrow as pa + >>> table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + >>> dataset = lance.write_dataset(table, "example") + >>> dataset.drop_columns(["a"]) + >>> dataset.to_table().to_pandas() + b + 0 a + 1 b + 2 c + """ + self._ds.drop_columns(columns) + # Indices might have changed + self._list_indices_res = None + def delete(self, predicate: Union[str, pa.compute.Expression]): """ Delete rows from the dataset. diff --git a/python/python/tests/test_schema_evolution.py b/python/python/tests/test_schema_evolution.py new file mode 100644 index 0000000000..5837e84cba --- /dev/null +++ b/python/python/tests/test_schema_evolution.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024. Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import lance +import pyarrow as pa +import pyarrow.compute as pc +import pytest + + +def test_drop_columns(tmp_path: Path): + dims = 32 + nrows = 512 + values = pc.random(nrows * dims).cast("float32") + table = pa.table({ + "a": pa.FixedSizeListArray.from_arrays(values, dims), + "b": range(nrows), + "c": range(nrows), + }) + dataset = lance.write_dataset(table, tmp_path) + dataset.create_index("a", "IVF_PQ", num_partitions=2, num_sub_vectors=1) + + # Drop a column, index is kept + dataset.drop_columns(["b"]) + assert dataset.schema == pa.schema({ + "a": pa.list_(pa.float32(), dims), + "c": pa.int64(), + }) + assert len(dataset.list_indices()) == 1 + + # Drop vector column, index is dropped + dataset.drop_columns(["a"]) + assert dataset.schema == pa.schema({"c": pa.int64()}) + assert len(dataset.list_indices()) == 0 + + # Can't drop all columns + with pytest.raises(ValueError): + dataset.drop_columns(["c"]) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index a80e96629b..66e27b4b0a 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1003,6 +1003,19 @@ impl Dataset { RT.block_on(None, self.ds.validate())? .map_err(|err| PyIOError::new_err(err.to_string())) } + + fn drop_columns(&mut self, columns: Vec<&str>) -> PyResult<()> { + let mut new_self = self.ds.as_ref().clone(); + RT.block_on(None, new_self.drop_columns(&columns))? + .map_err(|err| match err { + lance::Error::InvalidInput { source, .. } => { + PyValueError::new_err(source.to_string()) + } + _ => PyIOError::new_err(err.to_string()), + })?; + self.ds = Arc::new(new_self); + Ok(()) + } } impl Dataset { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 74a0c83bb3..29b1cdd1b5 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -897,39 +897,9 @@ impl Dataset { /// dataset. /// Parameters: /// - `columns`: the list of column names to drop. + #[deprecated(since = "0.9.12", note = "Please use `drop_columns` instead.")] pub async fn drop(&mut self, columns: &[&str]) -> Result<()> { - // Check if columns are present in the dataset and construct the new schema. - for col in columns { - if self.schema().field(col).is_none() { - return Err(Error::invalid_input( - format!("Column {} does not exist in the dataset", col), - location!(), - )); - } - } - - let columns_to_remove = self.manifest.schema.project(columns)?; - let new_schema = self.manifest.schema.exclude(columns_to_remove)?; - - let transaction = Transaction::new( - self.manifest.version, - Operation::Project { schema: new_schema }, - None, - ); - - let manifest = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - ) - .await?; - - self.manifest = Arc::new(manifest); - - Ok(()) + self.drop_columns(columns).await } /// Create a Scanner to scan the dataset. @@ -1466,6 +1436,56 @@ impl Dataset { } } +impl Dataset { + /// Remove columns from the dataset. + /// + /// This is a metadata-only operation and does not remove the data from the + /// underlying storage. 
In order to remove the data, you must subsequently + /// call `compact_files` to rewrite the data without the removed columns and + /// then call `cleanup_files` to remove the old files. + pub async fn drop_columns(&mut self, columns: &[&str]) -> Result<()> { + // Check if columns are present in the dataset and construct the new schema. + for col in columns { + if self.schema().field(col).is_none() { + return Err(Error::invalid_input( + format!("Column {} does not exist in the dataset", col), + location!(), + )); + } + } + + let columns_to_remove = self.manifest.schema.project(columns)?; + let new_schema = self.manifest.schema.exclude(columns_to_remove)?; + + if new_schema.fields.is_empty() { + return Err(Error::invalid_input( + "Cannot drop all columns from a dataset", + location!(), + )); + } + + let transaction = Transaction::new( + self.manifest.version, + Operation::Project { schema: new_schema }, + None, + ); + + let manifest = commit_transaction( + self, + &self.object_store, + self.commit_handler.as_ref(), + &transaction, + &Default::default(), + &Default::default(), + ) + .await?; + + self.manifest = Arc::new(manifest); + + Ok(()) + } +} + #[derive(Debug)] pub(crate) struct ManifestWriteConfig { auto_set_feature_flags: bool, // default true @@ -1574,7 +1594,7 @@ mod tests { cast::{as_string_array, as_struct_array}, types::Int32Type, ArrayRef, DictionaryArray, Float32Array, Int32Array, Int64Array, Int8Array, - Int8DictionaryArray, ListArray, RecordBatch, RecordBatchIterator, StringArray, UInt16Array, + Int8DictionaryArray, RecordBatch, RecordBatchIterator, StringArray, UInt16Array, UInt32Array, }; use arrow_ord::sort::sort_to_indices; @@ -2628,9 +2648,8 @@ mod tests { } #[tokio::test] - async fn test_drop() { - let mut metadata: HashMap = HashMap::new(); - metadata.insert(String::from("k1"), String::from("v1")); + async fn test_drop_columns() -> Result<()> { + let metadata: HashMap = [("k1".into(), "v1".into())].into(); let schema = Arc::new(ArrowSchema::new_with_metadata( vec![ @@ -2638,23 +2657,8 @@ mod tests { Field::new( "s", DataType::Struct(ArrowFields::from(vec![ - Field::new( - "d", - DataType::Dictionary( - Box::new(DataType::UInt32), - Box::new(DataType::Utf8), - ), - true, - ), - ArrowField::new( - "l", - DataType::List(Arc::new(ArrowField::new( - "item", - DataType::Int32, - true, - ))), - true, - ), + Field::new("d", DataType::Int32, true), + Field::new("l", DataType::Int32, true), ])), true, ), @@ -2663,270 +2667,65 @@ mod tests { metadata.clone(), )); - let struct_array_1 = Arc::new(StructArray::from(vec![ - ( - Arc::new(ArrowField::new( - "d", - DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), - true, - )), - Arc::new( - DictionaryArray::try_new( - UInt32Array::from(vec![1, 0]), - Arc::new(StringArray::from(vec!["A", "C", "G", "T"])), - ) - .unwrap(), - ) as ArrayRef, - ), - ( - Arc::new(ArrowField::new( - "l", - DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), - true, - )), - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1i32), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - ])), - ), - ])); - let struct_array_2 = Arc::new(StructArray::from(vec![ - ( - Arc::new(ArrowField::new( - "d", - DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), - true, - )), - Arc::new( - DictionaryArray::try_new( - UInt32Array::from(vec![2, 1]), - Arc::new(StringArray::from(vec!["A", "C", "G", "T"])), - ) - .unwrap(), - ) as ArrayRef, - ), - ( - Arc::new(ArrowField::new( - "l", - 
DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), - true, - )), - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(4), Some(5)]), - Some((0..2_000).map(Some).collect::>()), - ])), - ), - ])); - - let struct_array_full_1 = Arc::new(StructArray::from(vec![ - ( - Arc::new(ArrowField::new( - "d", - DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), - true, - )), - Arc::new( - DictionaryArray::try_new( - UInt32Array::from(vec![1, 0, 2, 1]), - Arc::new(StringArray::from(vec!["A", "C", "G", "T"])), - ) - .unwrap(), - ) as ArrayRef, - ), - ( - Arc::new(ArrowField::new( - "l", - DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), - true, - )), - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1i32), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - Some(vec![Some(4), Some(5)]), - Some((0..2_000).map(Some).collect::>()), - ])), - ), - ])); - - let struct_array_full_2 = Arc::new(StructArray::from(vec![( - Arc::new(ArrowField::new( - "l", - DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), - true, - )), - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1i32), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - Some(vec![Some(4), Some(5)]), - Some((0..2_000).map(Some).collect::>()), - ])) as ArrayRef, - )])); - - let batch1 = RecordBatch::try_new( + let batch = RecordBatch::try_new( schema.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2])), - struct_array_1.clone(), + Arc::new(StructArray::from(vec![ + ( + Arc::new(ArrowField::new("d", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + ), + ( + Arc::new(ArrowField::new("l", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2])), + ), + ])), Arc::new(Float32Array::from(vec![1.0, 2.0])), ], - ) - .unwrap(); - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![3, 2])), - struct_array_2.clone(), - Arc::new(Float32Array::from(vec![3.0, 4.0])), - ], - ) - .unwrap(); + )?; - let test_dir = tempdir().unwrap(); + let test_dir = tempdir()?; let test_uri = test_dir.path().to_str().unwrap(); - let write_params = WriteParams { - mode: WriteMode::Append, - ..Default::default() - }; - - let batches = - RecordBatchIterator::new(vec![batch1.clone()].into_iter().map(Ok), schema.clone()); - Dataset::write(batches, test_uri, Some(write_params.clone())) - .await - .unwrap(); - - let batches = - RecordBatchIterator::new(vec![batch2.clone()].into_iter().map(Ok), schema.clone()); - Dataset::write(batches, test_uri, Some(write_params.clone())) - .await - .unwrap(); - - let expected_drop_x = RecordBatch::try_new( - Arc::new(ArrowSchema::new_with_metadata( - vec![ - Field::new("i", DataType::Int32, false), - schema.fields[1].as_ref().clone(), - ], - metadata.clone(), - )), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 2])), - struct_array_full_1.clone(), - ], - ) - .unwrap(); - - let expected_drop_s_d = RecordBatch::try_new( - Arc::new(ArrowSchema::new_with_metadata( - vec![ - Field::new("i", DataType::Int32, false), - Field::new( - "s", - DataType::Struct(ArrowFields::from(vec![ArrowField::new( - "l", - DataType::List(Arc::new(ArrowField::new( - "item", - DataType::Int32, - true, - ))), - true, - )])), - true, - ), - Field::new("x", DataType::Float32, false), - ], - metadata.clone(), - )), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 2])), - struct_array_full_2.clone(), - Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - ) 
- .unwrap(); - - let dataset = Dataset::open(test_uri).await.unwrap(); - assert_eq!(dataset.fragments().len(), 2); - assert_eq!(dataset.manifest.max_fragment_id(), Some(1)); - - let mut dataset = Dataset::open(test_uri).await.unwrap(); - dataset.drop(&["x"]).await.unwrap(); - dataset.validate().await.unwrap(); - - assert_eq!(dataset.schema().fields.len(), 2); - assert_eq!(dataset.schema().metadata, metadata.clone()); - assert_eq!(dataset.version().version, 3); - assert_eq!(dataset.fragments().len(), 2); - assert_eq!(dataset.fragments()[0].files.len(), 1); - assert_eq!(dataset.fragments()[1].files.len(), 1); - assert_eq!(dataset.manifest.max_fragment_id(), Some(1)); + let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let mut dataset = Dataset::write(batches, test_uri, None).await?; - let actual_batches = dataset - .scan() - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - let actual = concat_batches(&actual_batches[0].schema(), &actual_batches).unwrap(); + let lance_schema = dataset.schema().clone(); + let original_fragments = dataset.fragments().to_vec(); - assert_eq!(actual, expected_drop_x); + dataset.drop_columns(&["x"]).await?; + dataset.validate().await?; - // Validate we can still read after re-instantiating dataset, which - // clears the cache. - let dataset = Dataset::open(test_uri).await.unwrap(); - let actual_batches = dataset - .scan() - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - let actual = concat_batches(&actual_batches[0].schema(), &actual_batches).unwrap(); + let expected_schema = lance_schema.project(&["i", "s"])?; + assert_eq!(dataset.schema(), &expected_schema); - assert_eq!(actual, expected_drop_x); + assert_eq!(dataset.version().version, 2); + assert_eq!(dataset.fragments().as_ref(), &original_fragments); - let overwrite_params = WriteParams { - mode: WriteMode::Overwrite, - ..Default::default() - }; + dataset.drop_columns(&["s.d"]).await?; + dataset.validate().await?; - let batches = - RecordBatchIterator::new(vec![batch1.clone()].into_iter().map(Ok), schema.clone()); - Dataset::write(batches, test_uri, Some(overwrite_params.clone())) - .await - .unwrap(); - - let batches = - RecordBatchIterator::new(vec![batch2.clone()].into_iter().map(Ok), schema.clone()); - Dataset::write(batches, test_uri, Some(write_params.clone())) - .await - .unwrap(); + let expected_schema = expected_schema.project(&["i", "s.l"])?; + assert_eq!(dataset.schema(), &expected_schema); - let mut dataset = Dataset::open(test_uri).await.unwrap(); - dataset.drop(&["s.d"]).await.unwrap(); - dataset.validate().await.unwrap(); + let expected_data = RecordBatch::try_new( + Arc::new(ArrowSchema::from(&expected_schema)), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StructArray::from(vec![( + Arc::new(ArrowField::new("l", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )])), + ], + )?; + let actual_data = dataset.scan().try_into_batch().await?; + assert_eq!(actual_data, expected_data); - assert_eq!(dataset.schema().fields.len(), 3); - assert_eq!(dataset.schema().metadata, metadata.clone()); - assert_eq!(dataset.version().version, 6); - assert_eq!(dataset.fragments().len(), 2); - assert_eq!(dataset.fragments()[0].files.len(), 1); - assert_eq!(dataset.fragments()[1].files.len(), 1); - assert_eq!(dataset.manifest.max_fragment_id(), Some(2)); + assert_eq!(dataset.version().version, 3); + assert_eq!(dataset.fragments().as_ref(), &original_fragments); - let 
actual_batches = dataset - .scan() - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - let actual = concat_batches(&actual_batches[0].schema(), &actual_batches).unwrap(); - assert_eq!(actual, expected_drop_s_d); + Ok(()) } #[tokio::test] diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index 75480577c1..ffd7c5f12a 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -492,6 +492,10 @@ impl Transaction { } Operation::Project { .. } => { final_fragments.extend(maybe_existing_fragments?.clone()); + + // Some fields that have indices may have been removed, so we should + // remove those indices as well. + Self::retain_relevant_indices(&mut final_indices, &schema) } Operation::Restore { .. } => { unreachable!() @@ -522,6 +526,19 @@ impl Transaction { Ok((manifest, final_indices)) } + fn retain_relevant_indices(indices: &mut Vec, schema: &Schema) { + let field_ids = schema + .fields_pre_order() + .map(|f| f.id) + .collect::>(); + indices.retain(|existing_index| { + existing_index + .fields + .iter() + .all(|field_id| field_ids.contains(field_id)) + }); + } + fn recalculate_fragment_bitmap( old: &RoaringBitmap, groups: &[RewriteGroup],
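
A minimal end-to-end sketch of the intended workflow from the Python side, for reviewers trying the change out. As the docstring above notes, drop_columns() is metadata-only, so a follow-up compaction and cleanup are what actually reclaim storage. The compaction/cleanup entry points shown here (dataset.optimize.compact_files() and dataset.cleanup_old_versions()) are assumptions based on current pylance and are not part of this patch; treat this as illustrative rather than canonical.

    import shutil
    from datetime import timedelta

    import lance
    import pyarrow as pa

    # Write a small two-column dataset, then drop one column.
    table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    dataset = lance.write_dataset(table, "/tmp/drop_columns_demo")
    dataset.drop_columns(["a"])  # metadata-only: the data files still contain "a"

    # Rewrite the data files without the dropped column, then remove old
    # versions that still reference it. Both calls are assumed pylance APIs,
    # not introduced by this patch.
    dataset.optimize.compact_files()
    dataset.cleanup_old_versions(older_than=timedelta(seconds=0))

    print(dataset.to_table().to_pandas())  # only column "b" remains

    shutil.rmtree("/tmp/drop_columns_demo")  # tidy up the demo directory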