From 085b4d97f742334477fb3b3c9357e1795479200d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 Jan 2024 13:57:07 -0800 Subject: [PATCH] fix: shard batch iterator can reads partial batches (#1889) --- .../lance/_dataset/sharded_batch_iterator.py | 2 +- python/python/tests/test_dataset.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/python/lance/_dataset/sharded_batch_iterator.py b/python/python/lance/_dataset/sharded_batch_iterator.py index 8b1abafd88..f405d4985d 100644 --- a/python/python/lance/_dataset/sharded_batch_iterator.py +++ b/python/python/lance/_dataset/sharded_batch_iterator.py @@ -141,7 +141,7 @@ def _gen_ranges(): total, self._world_size * self._batch_size, ): - yield start, start + self._batch_size + yield start, min(start + self._batch_size, total) return self._ds._ds.take_scan( _gen_ranges(), diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index dea752c303..3eaefbe71c 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1416,3 +1416,22 @@ def test_sharded_iterator_batches(tmp_path: Path): for j in range(i, i + BATCH_SIZE) ] ) + + +def test_sharded_iterator_non_full_batch(tmp_path: Path): + arr = pa.array(range(1186)) + tbl = pa.table({"a": arr}) + + ds = lance.write_dataset(tbl, tmp_path) + shard_datast = ShardedBatchIterator( + ds, + 1, + 2, + columns=["a"], + batch_size=100, + granularity="batch", + ) + batches = pa.concat_arrays([b["a"] for b in shard_datast]) + + # Can read partial batches + assert len(set(range(1100, 1186)) - set(batches.to_pylist())) == 0