
[Data] Update Dataset.count() to avoid unnecessarily keeping BlockRefs in-memory #46369

Merged: 21 commits, Jul 10, 2024
38 changes: 30 additions & 8 deletions python/ray/data/dataset.py
@@ -12,6 +12,7 @@
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
Mapping,
@@ -2486,11 +2487,12 @@ def count(self) -> int:

get_num_rows = cached_remote_fn(_get_num_rows)

return sum(
ray.get(
[get_num_rows.remote(block) for block in self.get_internal_block_refs()]
)
)
# Directly loop over the iterator of `BlockRef`s instead of first
# retrieving a list of `BlockRef`s.
total_rows = 0
for block_ref in self.iter_internal_block_refs():
total_rows += ray.get(get_num_rows.remote(block_ref))
return total_rows

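The diff above swaps an eager list of `ObjectRef`s for direct iteration over them. A minimal sketch of the two patterns (plain Python, no Ray; `fetch` is a hypothetical stand-in for `ray.get(get_num_rows.remote(block_ref))`):

```python
from typing import Iterator, List


def fetch(ref: int) -> int:
    # Stand-in for ray.get(get_num_rows.remote(block_ref)).
    return ref


def count_eager(block_refs: List[int]) -> int:
    # Old pattern: the full list of refs is built up front and stays
    # alive until every row count has been retrieved.
    return sum(fetch(ref) for ref in block_refs)


def count_streaming(block_refs: Iterator[int]) -> int:
    # New pattern: consume the iterator directly, so each ref can be
    # released as soon as its row count has been fetched.
    total_rows = 0
    for block_ref in block_refs:
        total_rows += fetch(block_ref)
    return total_rows


print(count_streaming(iter([3, 4, 5])))  # 12
```

Both return the same total; the difference is only how long the references stay reachable on the driver.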
@ConsumptionAPI(
if_more_than_read=True,
@@ -4563,7 +4565,29 @@ def stats(self) -> str:
def _get_stats_summary(self) -> DatasetStatsSummary:
return self._plan.stats_summary()

@ConsumptionAPI(pattern="Time complexity:")
@ConsumptionAPI(pattern="")
@DeveloperAPI
def iter_internal_block_refs(self) -> Iterator[ObjectRef[Block]]:
"""Get an iterator over references to the underlying blocks of this Dataset.

This function can be used for zero-copy access to the data. It does not
Member (suggested change):
- This function can be used for zero-copy access to the data. It does not
+ This function can be used for zero-copy access to the data. It doesn't

keep the data materialized in-memory.
Member:
What does zero-copy access mean here? You might copy the data when you get the block reference, right?

Contributor Author:
I had thought that when we get the RefBundle / BlockRef, it does not copy the data. That's the main advantage of passing the references instead of the blocks themselves, right?

Member:
Oh. Yeah, if you don't call ray.get there won't be any copies, although the way I read this makes it sound like I can access the actual Block without copies.

Contributor Author:
Good point, let me just remove the line. I think saying "It does not keep the data materialized in-memory." is the more important main point to get across.


Examples:
>>> import ray
>>> ds = ray.data.range(1)
>>> for block_ref in ds.iter_internal_block_refs():
... block = ray.get(block_ref)

Returns:
An iterator over references to this Dataset's blocks.
"""
iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
iter_block_refs = (block_ref for block_ref, _ in iter_block_refs_md)
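Worth noting: a generator expression like the one above is lazy, so the underlying `(block_ref, metadata)` pairs are only pulled from the execution iterator as the caller consumes the result (toy demonstration with plain tuples standing in for refs and metadata):

```python
pairs = iter([("ref0", "meta0"), ("ref1", "meta1")])

# Nothing is consumed from `pairs` at this point; the generator
# expression only records how to produce values on demand.
block_refs = (block_ref for block_ref, _meta in pairs)

first = next(block_refs)  # pulls exactly one pair from `pairs`
print(first)  # ref0
```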
Contributor:
Just realized that we already have block metadata here, so there's no need to submit additional tasks to count rows.
We can update this method to return Iterator[RefBundle].

Member:
+1

self._synchronize_progress_bar()
return iter_block_refs
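The reviewer's observation that the iterator already carries block metadata suggests counting rows without any extra remote tasks. A hedged sketch of that idea (the `BlockMetadata` class below is a stand-in, though Ray's real `BlockMetadata` does expose a `num_rows` field):

```python
from dataclasses import dataclass
from typing import Iterator, Optional, Tuple


@dataclass
class BlockMetadata:
    # Stand-in for ray.data.block.BlockMetadata, which carries an
    # Optional[int] num_rows for each block.
    num_rows: Optional[int]


def count_from_metadata(blocks: Iterator[Tuple[str, BlockMetadata]]) -> int:
    # If every block's metadata already knows its row count, no remote
    # get_num_rows task is needed: just sum as the iterator is consumed.
    total = 0
    for _block_ref, meta in blocks:
        assert meta.num_rows is not None  # fall back to tasks otherwise
        total += meta.num_rows
    return total


print(count_from_metadata(iter([("ref0", BlockMetadata(5)), ("ref1", BlockMetadata(7))])))  # 12
```

In the real code path, metadata without a known `num_rows` would still need a remote task; the sketch only covers the happy path the reviewer describes.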

@ConsumptionAPI(pattern="")
@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
Contributor:
(Can do this later.) There are only a few use cases of get_internal_block_refs; we can also update them to use iter_internal_block_refs.

"""Get a list of references to the underlying blocks of this dataset.
@@ -4577,8 +4601,6 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
>>> ds.get_internal_block_refs()
[ObjectRef(...)]

Time complexity: O(1)
Contributor Author:
removing this because it's no longer accurate.


Returns:
A list of references to this dataset's blocks.
"""
17 changes: 17 additions & 0 deletions python/ray/data/tests/test_formats.py
@@ -85,6 +85,23 @@ def test_get_internal_block_refs(ray_start_regular_shared):
assert out == list(range(10)), out


def test_iter_internal_block_refs(ray_start_regular_shared):
n = 10
iter_block_refs = ray.data.range(
n, override_num_blocks=n
).iter_internal_block_refs()

out = []
block_ref_count = 0
for block_ref in iter_block_refs:
b = ray.get(block_ref)
out.extend(extract_values("id", BlockAccessor.for_block(b).iter_rows(True)))
block_ref_count += 1
out = sorted(out)
assert block_ref_count == n
assert out == list(range(n)), out


def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
"""Same as `test_parquet_write` but using a custom, fsspec filesystem.
