fix: allow merge() with dataset as input (#1869)
wjones127 authored Jan 31, 2024
Parent e709e7d · Commit 5807cde
Showing 4 changed files with 40 additions and 4 deletions.
python/python/tests/test_dataset.py (18 additions, 0 deletions)
@@ -782,6 +782,24 @@ def test_merge_data(tmp_path: Path):
     assert dataset.to_table() == expected
 
 
+def test_merge_from_dataset(tmp_path: Path):
+    tab1 = pa.table({"a": range(100), "b": range(100)})
+    ds1 = lance.write_dataset(tab1, tmp_path / "dataset1", mode="append")
+
+    tab2 = pa.table({"a": range(100), "c": range(100)})
+    ds2 = lance.write_dataset(tab2, tmp_path / "dataset2", mode="append")
+
+    ds1.merge(ds2.to_batches(), "a", schema=ds2.schema)
+    assert ds1.version == 2
+    assert ds1.to_table() == pa.table(
+        {
+            "a": range(100),
+            "b": range(100),
+            "c": range(100),
+        }
+    )
+
+
 def test_delete_data(tmp_path: Path):
     # We pass schema explicitly since we want b to be non-nullable.
     schema = pa.schema(
python/src/dataset.rs (9 additions, 3 deletions)
@@ -631,11 +631,17 @@ impl Dataset {
     fn merge(
         &mut self,
         reader: PyArrowType<ArrowArrayStreamReader>,
-        left_on: &str,
-        right_on: &str,
+        left_on: String,
+        right_on: String,
     ) -> PyResult<()> {
         let mut new_self = self.ds.as_ref().clone();
-        RT.block_on(None, new_self.merge(reader.0, left_on, right_on))?
+        let new_self = RT
+            .spawn(None, async move {
+                new_self
+                    .merge(reader.0, &left_on, &right_on)
+                    .await
+                    .map(|_| new_self)
+            })?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
         self.ds = Arc::new(new_self);
         Ok(())
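
A note on the signature change above: RT.spawn takes a Future + Send + 'static task (see the bound on spawn in executor.rs below), so the spawned future cannot borrow &str arguments from the caller's stack frame; left_on and right_on therefore become owned Strings that are moved into the async move block along with the cloned dataset. A rough standalone sketch of that ownership requirement, using a plain Tokio runtime and a hypothetical merge_keys helper rather than Lance's BackgroundExecutor:

use tokio::runtime::Runtime;

// Hypothetical helper (not part of Lance): the spawned future must be Send +
// 'static, so the join keys are passed by value and moved into the task
// instead of being borrowed as &str from the caller's frame.
fn merge_keys(rt: &Runtime, left_on: String, right_on: String) -> String {
    let handle = rt.spawn(async move {
        // The future owns its data; nothing borrows from merge_keys' stack.
        format!("join on {left_on} = {right_on}")
    });
    // Waiting on the JoinHandle from a thread outside the runtime is fine.
    rt.block_on(handle).expect("merge task panicked")
}

fn main() {
    let rt = Runtime::new().expect("failed to build Tokio runtime");
    assert_eq!(merge_keys(&rt, "a".into(), "a".into()), "join on a = a");
}
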
python/src/executor.rs (7 additions, 0 deletions)
@@ -44,6 +44,9 @@ impl BackgroundExecutor {
     }
 
     /// Spawn a task and wait for it to complete.
+    ///
+    /// This method is safe to use with inputs that may reference a Rust async
+    /// runtime.
     pub fn spawn<T>(&self, py: Option<Python<'_>>, task: T) -> PyResult<T::Output>
     where
         T: Future + Send + 'static,
@@ -119,6 +122,10 @@ impl BackgroundExecutor {
     /// Block on a future and wait for it to complete.
     ///
     /// This helper method also frees the GIL before blocking.
+    ///
+    /// This method is NOT safe to use with inputs that may reference a Rust async
+    /// runtime. If the future references an async runtime, it will panic on an
+    /// error: "Cannot start a runtime from within a runtime."
     pub fn block_on<F: Future + Send>(
         &self,
         py: Option<Python<'_>>,
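
The new doc comments capture the distinction the fix relies on: block_on drives the future on the calling thread, so if that future (or an input it polls) re-enters the runtime with another block_on, Tokio panics with "Cannot start a runtime from within a runtime", whereas spawn runs the work on the runtime's own worker threads while the caller only waits for the result. A minimal standalone sketch of both sides, independent of Lance's BackgroundExecutor:

use tokio::runtime::Runtime;

fn main() {
    let outer = Runtime::new().expect("failed to build Tokio runtime");
    let inner = Runtime::new().expect("failed to build Tokio runtime");

    // Safe: the future runs on the outer runtime's worker threads; the
    // calling thread only waits on the JoinHandle.
    let value = outer
        .block_on(outer.spawn(async { 40 + 2 }))
        .expect("task panicked");
    assert_eq!(value, 42);

    // Panics with "Cannot start a runtime from within a runtime": the outer
    // block_on is already driving async tasks on this thread, and the nested
    // block_on then tries to block that same thread again.
    outer.block_on(async {
        inner.block_on(async { 40 + 2 });
    });
}
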
rust/lance/src/dataset/hash_joiner.rs (6 additions, 1 deletion)
@@ -56,7 +56,12 @@ impl HashJoiner {
         schema.field_with_name(on)?;
 
         // Hold all data in memory for simple implementation. Can do external sort later.
-        let batches = reader.collect::<std::result::Result<Vec<RecordBatch>, _>>()?;
+        // This is a blocking operation, so we'll run it in a separate thread.
+        let batches = tokio::task::spawn_blocking(|| {
+            reader.collect::<std::result::Result<Vec<RecordBatch>, _>>()
+        })
+        .await
+        .unwrap()?;
         if batches.is_empty() {
             return Err(Error::IO {
                 message: "HashJoiner: No data".to_string(),
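
The HashJoiner change moves the synchronous reader.collect() onto Tokio's blocking thread pool via tokio::task::spawn_blocking, so draining a reader that may block (or itself reference the runtime) never stalls an async worker thread. A rough standalone sketch of the same pattern, with a plain fallible iterator standing in for the Arrow record-batch reader:

use tokio::task;

#[tokio::main]
async fn main() {
    // Stand-in for a blocking record-batch reader: a fallible iterator.
    let reader = (0..100u64).map(Ok::<u64, std::io::Error>);

    // Drain the reader on the blocking pool rather than an async worker thread.
    let rows: Vec<u64> = task::spawn_blocking(move || {
        reader.collect::<Result<Vec<_>, _>>()
    })
    .await
    .expect("blocking task panicked") // JoinError if the closure panicked
    .expect("reader returned an error");

    assert_eq!(rows.len(), 100);
}
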
