feat: add "merge insert" operation based on merge operation in other databases (#1647)

The "merge insert" operation can insert new rows, delete old rows, and
update old rows, all in a single transaction. It is a generic operation
that is used to provide upsert, find-or-create, and "replace range".

closes #1456
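
As a sketch of the new Python API, the example below chains the builder methods
added in this commit to perform an upsert. The dataset path and column names
are illustrative only and are not taken from this commit.

import lance
import pyarrow as pa

# Build a small dataset keyed on "id".
table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
dataset = lance.write_dataset(table, "./example_dataset", mode="overwrite")

# Source data: id 3 already exists (will be updated), id 4 is new (will be inserted).
new_data = pa.table({"id": [3, 4], "value": ["c2", "d"]})

# Upsert: update matched rows and insert unmatched rows in a single transaction.
(
    dataset.merge_insert("id")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_data)
)

print(dataset.to_table().sort_by("id"))  # ids 1-4, with id 3 updated to "c2"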

---------

Co-authored-by: Will Jones <[email protected]>
westonpace and wjones127 authored Jan 30, 2024
1 parent b3db3cc commit 2cd296f
Showing 20 changed files with 1,436 additions and 113 deletions.
2 changes: 2 additions & 0 deletions python/python/lance/__init__.py
@@ -20,6 +20,7 @@
LanceDataset,
LanceOperation,
LanceScanner,
MergeInsertBuilder,
__version__,
write_dataset,
)
@@ -41,6 +42,7 @@
"LanceDataset",
"LanceOperation",
"LanceScanner",
"MergeInsertBuilder",
"__version__",
"write_dataset",
"schema_to_json",
21 changes: 20 additions & 1 deletion python/python/lance/dataset.py
@@ -50,7 +50,14 @@
from .dependencies import numpy as np
from .dependencies import pandas as pd
from .fragment import FragmentMetadata, LanceFragment
from .lance import CleanupStats, _Dataset, _Operation, _Scanner, _write_dataset
from .lance import (
CleanupStats,
_Dataset,
_MergeInsertBuilder,
_Operation,
_Scanner,
_write_dataset,
)
from .lance import CompactionMetrics as CompactionMetrics
from .lance import __version__ as __version__
from .optimize import Compaction
@@ -80,6 +87,12 @@
]


class MergeInsertBuilder(_MergeInsertBuilder):
def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
reader = _coerce_reader(data_obj, schema)
super(MergeInsertBuilder, self).execute(reader)


class LanceDataset(pa.dataset.Dataset):
"""A dataset in Lance format where the data is stored at the given uri."""

@@ -630,6 +643,12 @@ def delete(self, predicate: Union[str, pa.compute.Expression]):
predicate = str(predicate)
self._ds.delete(predicate)

def merge_insert(
self,
on: Union[str, Iterable[str]],
):
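"""Create a builder for a "merge insert" operation keyed on ``on``.

The returned ``MergeInsertBuilder`` can insert rows that are not yet in
the dataset, update rows that match, and delete rows not matched by the
source data, all in a single transaction. Chain the ``when_*`` methods
to choose the behavior, then call ``execute`` with the new data.
"""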
return MergeInsertBuilder(self._ds, on)

def update(
self,
updates: Dict[str, str],
187 changes: 187 additions & 0 deletions python/python/tests/test_dataset.py
@@ -822,6 +822,193 @@ def test_delete_data(tmp_path: Path):
assert dataset.count_rows() == 0


def test_merge_insert(tmp_path: Path):
nrows = 1000
table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
dataset = lance.write_dataset(
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)
version = dataset.version

new_table = pa.Table.from_pydict(
{"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
)

is_new = pc.field("b") == 2

dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
table = dataset.to_table()
assert table.num_rows == 1000
assert table.filter(is_new).num_rows == 700

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert(
"a"
).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 1000

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table)
table = dataset.to_table()
assert table.num_rows == 700
assert table.filter(is_new).num_rows == 0

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_not_matched_by_source_delete(
"a < 100"
).when_not_matched_insert_all().execute(new_table)

table = dataset.to_table()
assert table.num_rows == 1200
assert table.filter(is_new).num_rows == 300

# If the user doesn't specify anything then the merge_insert is
# a no-op and the operation fails
dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
with pytest.raises(ValueError):
dataset.merge_insert("a").execute(new_table)


def test_merge_insert_source_is_dataset(tmp_path: Path):
nrows = 1000
table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
dataset = lance.write_dataset(
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)
version = dataset.version

new_table = pa.Table.from_pydict(
{"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
)
new_dataset = lance.write_dataset(
new_table, tmp_path / "dataset2", mode="create", max_rows_per_file=80
)

is_new = pc.field("b") == 2

dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()

reader = new_dataset.to_batches()

dataset.merge_insert("a").when_not_matched_insert_all().execute(
reader, schema=new_dataset.schema
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300


def test_merge_insert_multiple_keys(tmp_path: Path):
nrows = 1000
# a - [0, 1, 2, ..., 999]
# b - [1, 1, 1, ..., 1]
# c - [0, 1, 0, ..., 1]
table = pa.Table.from_pydict(
{
"a": range(nrows),
"b": [1 for _ in range(nrows)],
"c": [i % 2 for i in range(nrows)],
}
)
dataset = lance.write_dataset(
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)

# a - [300, 301, 302, ..., 1299]
# b - [2, 2, 2, ..., 2]
# c - [0, 0, 0, ..., 0]
new_table = pa.Table.from_pydict(
{
"a": range(300, 300 + nrows),
"b": [2 for _ in range(nrows)],
"c": [0 for _ in range(nrows)],
}
)

is_new = pc.field("b") == 2

dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table)
table = dataset.to_table()
assert table.num_rows == 1000
assert table.filter(is_new).num_rows == 350


def test_merge_insert_incompatible_schema(tmp_path: Path):
nrows = 1000
table = pa.Table.from_pydict(
{
"a": range(nrows),
"b": [1 for _ in range(nrows)],
}
)
dataset = lance.write_dataset(
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)

new_table = pa.Table.from_pydict(
{
"a": range(300, 300 + nrows),
}
)

with pytest.raises(OSError):
dataset.merge_insert("a").when_matched_update_all().execute(new_table)


def test_merge_insert_vector_column(tmp_path: Path):
table = pa.Table.from_pydict(
{
"vec": pa.array([[1, 2, 3], [4, 5, 6]], pa.list_(pa.float32(), 3)),
"key": [1, 2],
}
)

new_table = pa.Table.from_pydict(
{
"vec": pa.array([[7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)),
"key": [2, 3],
}
)

dataset = lance.write_dataset(
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)

dataset.merge_insert(
["key"]
).when_not_matched_insert_all().when_matched_update_all().execute(new_table)

expected = pa.Table.from_pydict(
{
"vec": pa.array(
[[1, 2, 3], [7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)
),
"key": [1, 2, 3],
}
)

assert dataset.to_table().sort_by("key") == expected


def test_update_dataset(tmp_path: Path):
nrows = 100
vecs = pa.FixedSizeListArray.from_arrays(
3 changes: 2 additions & 1 deletion python/python/tests/test_fragment.py
@@ -136,7 +136,8 @@ def test_dataset_progress(tmp_path: Path):

assert fragment == FragmentMetadata.from_json(json.dumps(metadata))

p = multiprocessing.Process(target=failing_write, args=(progress_uri, dataset_uri))
ctx = multiprocessing.get_context("spawn")
p = ctx.Process(target=failing_write, args=(progress_uri, dataset_uri))
p.start()
try:
p.join()
104 changes: 100 additions & 4 deletions python/src/dataset.rs
@@ -25,11 +25,12 @@ use chrono::Duration;

use futures::{StreamExt, TryFutureExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::UpdateBuilder;
use lance::dataset::{
fragment::FileFragment as LanceFileFragment, progress::WriteFragmentProgress,
scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation,
Dataset as LanceDataset, ReadParams, Version, WriteMode, WriteParams,
Dataset as LanceDataset, MergeInsertBuilder as LanceMergeInsertBuilder, ReadParams,
UpdateBuilder, Version, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, WriteMode,
WriteParams,
};
use lance::index::{
scalar::ScalarIndexParams,
@@ -47,9 +48,9 @@ use lance_linalg::distance::MetricType;
use lance_table::format::Fragment;
use lance_table::io::commit::CommitHandler;
use object_store::path::Path;
use pyo3::exceptions::PyStopIteration;
use pyo3::exceptions::{PyStopIteration, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyList, PySet};
use pyo3::types::{PyList, PySet, PyString};
use pyo3::{
exceptions::{PyIOError, PyKeyError, PyValueError},
pyclass,
@@ -93,6 +94,101 @@ fn convert_schema(arrow_schema: &ArrowSchema) -> PyResult<Schema> {
})
}

#[pyclass(name = "_MergeInsertBuilder", module = "_lib", subclass)]
pub struct MergeInsertBuilder {
builder: LanceMergeInsertBuilder,
dataset: Py<Dataset>,
}

#[pymethods]
impl MergeInsertBuilder {
#[new]
pub fn new(dataset: &PyAny, on: &PyAny) -> PyResult<Self> {
let dataset: Py<Dataset> = dataset.extract()?;
let ds = dataset.borrow(on.py()).ds.clone();
// Either a single string, which we put in a vector, or an iterator
// of strings, which we collect into a vector
let on = PyAny::downcast::<PyString>(on)
.map(|val| vec![val.to_string()])
.or_else(|_| {
let iterator = on.iter().map_err(|_| {
PyTypeError::new_err(
"The `on` argument to merge_insert must be a str or iterable of str",
)
})?;
let mut keys = Vec::new();
for key in iterator {
keys.push(PyAny::downcast::<PyString>(key?)?.to_string());
}
PyResult::Ok(keys)
})?;

let mut builder = LanceMergeInsertBuilder::try_new(ds, on)
.map_err(|err| PyValueError::new_err(err.to_string()))?;

// We don't have do_nothing methods in Python, so we start with a blank slate
builder
.when_matched(WhenMatched::DoNothing)
.when_not_matched(WhenNotMatched::DoNothing);

Ok(Self { builder, dataset })
}

pub fn when_matched_update_all(mut slf: PyRefMut<Self>) -> PyResult<PyRefMut<Self>> {
slf.builder.when_matched(WhenMatched::UpdateAll);
Ok(slf)
}

pub fn when_not_matched_insert_all(mut slf: PyRefMut<Self>) -> PyResult<PyRefMut<Self>> {
slf.builder.when_not_matched(WhenNotMatched::InsertAll);
Ok(slf)
}

pub fn when_not_matched_by_source_delete<'a>(
mut slf: PyRefMut<'a, Self>,
expr: Option<&str>,
) -> PyResult<PyRefMut<'a, Self>> {
let new_val = if let Some(expr) = expr {
let dataset = slf.dataset.borrow(slf.py());
WhenNotMatchedBySource::delete_if(&dataset.ds, expr)
.map_err(|err| PyValueError::new_err(err.to_string()))?
} else {
WhenNotMatchedBySource::Delete
};
slf.builder.when_not_matched_by_source(new_val);
Ok(slf)
}

pub fn execute(&mut self, new_data: &PyAny) -> PyResult<()> {
let py = new_data.py();

let new_data: Box<dyn RecordBatchReader + Send> = if new_data.is_instance_of::<Scanner>() {
let scanner: Scanner = new_data.extract()?;
Box::new(
RT.spawn(Some(py), async move { scanner.to_reader().await })?
.map_err(|err| PyValueError::new_err(err.to_string()))?,
)
} else {
Box::new(ArrowArrayStreamReader::from_pyarrow(new_data)?)
};

let job = self
.builder
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let new_self = RT
.spawn(Some(py), job.execute_reader(new_data))?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let dataset = self.dataset.as_ref(py);

dataset.borrow_mut().ds = new_self;

Ok(())
}
}

#[pymethods]
impl Operation {
fn __repr__(&self) -> String {