
[Data] Update Dataset.count() to avoid unnecessarily keeping BlockRefs in-memory #46369

Merged · 21 commits · Jul 10, 2024
1 change: 1 addition & 0 deletions doc/source/data/api/dataset.rst
@@ -131,6 +131,7 @@ Inspecting Metadata
Dataset.input_files
Dataset.stats
Dataset.get_internal_block_refs
Dataset.iter_internal_ref_bundles

Execution
---------
49 changes: 48 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
@@ -3,18 +3,20 @@
It should be deleted once we fully move to the new executor backend.
"""

from typing import Iterator, Tuple
from typing import Iterator, Optional, Tuple

from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.interfaces import (
Executor,
PhysicalOperator,
RefBundle,
)
from ray.data._internal.execution.interfaces.executor import OutputIterator
from ray.data._internal.logical.optimizers import get_execution_plan
from ray.data._internal.logical.util import record_operators_usage
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import Block, BlockMetadata
from ray.types import ObjectRef

@@ -59,6 +61,51 @@ def execute_to_legacy_bundle_iterator(
dag = dag_rewrite(dag)

bundle_iter = executor.execute(dag, initial_stats=stats)

class CacheMetadataIterator(OutputIterator):
"""Wrapper for `bundle_iterator` above.

For a given iterator which yields output RefBundles,
collect the metadata from each output bundle, and yield the
original RefBundle. Only after the entire iterator is exhausted,
we cache the resulting metadata to the execution plan."""

def __init__(self, base_iterator: OutputIterator):
# Note: the base_iterator should be of type StreamIterator,
# defined within `StreamingExecutor.execute()`. It must
# support the `get_next()` method.
self._base_iterator = base_iterator
self._collected_metadata = BlockMetadata(
num_rows=0,
size_bytes=0,
schema=None,
input_files=None,
exec_stats=None,
)

def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle:
try:
bundle = self._base_iterator.get_next(output_split_idx)
self._collect_metadata(bundle)
return bundle
except StopIteration:
# Once the iterator is completely exhausted, we are done
# collecting metadata. We can add this cached metadata to the plan.
plan._snapshot_metadata = self._collected_metadata
raise

def _collect_metadata(self, bundle: RefBundle) -> RefBundle:
"""Collect the metadata from each output bundle and accumulate
results, so we can access important information, such as
row count, schema, etc., after iteration completes."""
self._collected_metadata.num_rows += bundle.num_rows()
self._collected_metadata.size_bytes += bundle.size_bytes()
self._collected_metadata.schema = unify_block_metadata_schema(
[self._collected_metadata, *bundle.metadata]
)
return bundle

bundle_iter = CacheMetadataIterator(bundle_iter)
return bundle_iter
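The wrapper above boils down to a reusable pattern: accumulate lightweight metadata as items stream past, and publish the aggregate only once the underlying iterator is exhausted. A minimal sketch of that pattern in plain Python; the class and attribute names are illustrative, not part of Ray's API:

```python
from typing import Iterator, List, Optional


class MetadataCachingIterator:
    """Illustrative only: count rows while yielding batches, and expose the
    total only after the source iterator has been fully consumed."""

    def __init__(self, base: Iterator[List[int]]):
        self._base = base
        self._running_rows = 0
        self.cached_num_rows: Optional[int] = None  # set on exhaustion

    def __iter__(self) -> "MetadataCachingIterator":
        return self

    def __next__(self) -> List[int]:
        try:
            batch = next(self._base)
        except StopIteration:
            # Source exhausted: publish the accumulated metadata, then re-raise
            # so callers still see the normal end-of-iteration signal.
            self.cached_num_rows = self._running_rows
            raise
        self._running_rows += len(batch)
        return batch


it = MetadataCachingIterator(iter([[1, 2], [3], [4, 5, 6]]))
for _batch in it:
    pass  # consume the stream
assert it.cached_num_rows == 6
```

`CacheMetadataIterator` follows the same shape, except that the aggregate is a `BlockMetadata` and it is written back onto the execution plan on exhaustion.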


11 changes: 11 additions & 0 deletions python/ray/data/_internal/plan.py
@@ -72,6 +72,12 @@ def __init__(
self._snapshot_operator: Optional[LogicalOperator] = None
self._snapshot_stats = None
self._snapshot_bundle = None
# Snapshot of only metadata corresponding to the final operator's
# output bundles, used as the source of truth for the Dataset's schema
# and count. This is calculated and cached when the plan is executed as an
# iterator (`execute_to_iterator()`), and avoids caching
# all of the output blocks in memory like in `self.snapshot_bundle`.
self._snapshot_metadata: Optional[BlockMetadata] = None
Member: I don't have any ideas off the top of my head, but I think it'd be good if we simplify how we cache bundles and metadata at some point. It might be confusing how execute_to_iterator uses _snapshot_metadata but execute doesn't.

Contributor Author: Good point, added a TODO comment.


# Cached schema.
self._schema = None
@@ -148,6 +154,9 @@ def generate_logical_plan_string(
# This plan has executed some but not all operators.
schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
count = self._snapshot_bundle.num_rows()
elif self._snapshot_metadata is not None:
schema = self._snapshot_metadata.schema
count = self._snapshot_metadata.num_rows
else:
# This plan hasn't executed any operators.
sources = self._logical_plan.sources()
@@ -414,6 +423,8 @@ def execute_to_iterator(

metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag)
# TODO(scottjlee): replace with `execute_to_legacy_bundle_iterator` and
# update execute_to_iterator usages to handle RefBundles instead of Blocks
block_iter = execute_to_legacy_block_iterator(
executor,
self,
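The new field changes what the plan holds onto after streaming execution: instead of a snapshot bundle that pins every output block in the object store, only aggregated statistics are kept. A hedged sketch of the two snapshot shapes and the lookup order used when printing the plan; the names and types below are simplified stand-ins, not Ray's internal classes:

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class SnapshotBundle:
    """Stand-in for the fully materialized snapshot: block refs stay alive."""
    block_refs: List[Any] = field(default_factory=list)
    num_rows: int = 0


@dataclass
class SnapshotMetadata:
    """Stand-in for the new lightweight snapshot: statistics only."""
    num_rows: int = 0
    size_bytes: int = 0
    schema: Optional[Any] = None


def resolve_count(bundle: Optional[SnapshotBundle],
                  metadata: Optional[SnapshotMetadata]) -> Optional[int]:
    # Mirrors the order in generate_logical_plan_string(): a materialized
    # bundle wins, then the metadata-only snapshot, otherwise unknown ("?").
    if bundle is not None:
        return bundle.num_rows
    if metadata is not None:
        return metadata.num_rows
    return None
```

Either snapshot can answer `num_rows=` in `str(ds)`, but only the first keeps block references alive.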
78 changes: 57 additions & 21 deletions python/ray/data/dataset.py
@@ -12,6 +12,7 @@
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
Mapping,
@@ -77,6 +78,7 @@
VALID_BATCH_FORMATS,
Block,
BlockAccessor,
BlockMetadata,
DataBatch,
T,
U,
@@ -2459,12 +2461,14 @@ def show(self, limit: int = 20) -> None:
@ConsumptionAPI(
if_more_than_read=True,
datasource_metadata="row count",
pattern="Time complexity:",
pattern="Examples:",
)
def count(self) -> int:
"""Count the number of records in the dataset.
"""Count the number of rows in the dataset.

Time complexity: O(dataset size / parallelism), O(1) for parquet
For Datasets which only read Parquet files (created with
:meth:`~ray.data.read_parquet`), this method reads the file metadata to
efficiently count the number of rows without reading in the entire data.

Examples:
>>> import ray
@@ -2484,13 +2488,15 @@ def count(self) -> int:
if meta_count is not None:
return meta_count

get_num_rows = cached_remote_fn(_get_num_rows)

return sum(
ray.get(
[get_num_rows.remote(block) for block in self.get_internal_block_refs()]
)
)
# Directly loop over the iterator of `RefBundle`s instead of
# retrieving a full list of `BlockRef`s.
total_rows = 0
for ref_bundle in self.iter_internal_ref_bundles():
num_rows = ref_bundle.num_rows()
# Executing the dataset always returns blocks with valid `num_rows`.
assert num_rows is not None
total_rows += num_rows
Contributor: Not an issue with this PR, but we should make BlockMetadata.num_rows non-nullable, to avoid repeating this check.

Member: Yeah, let's definitely do this when we separate BlockMetadata from read tasks. Currently, BlockMetadata.num_rows must be nullable because some datasources don't know how many rows are yielded by each read task.

return total_rows

@ConsumptionAPI(
if_more_than_read=True,
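The practical difference is between materializing every block reference before summing and streaming bundles one at a time. A rough sketch of the two shapes of `count()`; the remote helper in the old path is commented out and purely illustrative, while `iter_internal_ref_bundles` is the API added by this PR:

```python
import ray

ds = ray.data.range(10_000)

# Old shape (illustrative): collect all block refs first, keeping every block
# alive until the remote row counts come back.
# block_refs = ds.get_internal_block_refs()
# total = sum(ray.get([get_num_rows.remote(ref) for ref in block_refs]))

# New shape: stream RefBundles and read row counts from their metadata, so
# each bundle can be dropped as soon as it has been counted.
total = 0
for bundle in ds.iter_internal_ref_bundles():
    num_rows = bundle.num_rows()
    assert num_rows is not None  # executed blocks always carry a row count
    total += num_rows
assert total == 10_000
```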
@@ -4328,14 +4334,15 @@ def to_pandas(self, limit: int = None) -> "pandas.DataFrame":
ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds
``limit``.
"""
count = self.count()
if limit is not None and count > limit:
raise ValueError(
f"the dataset has more than the given limit of {limit} "
f"rows: {count}. If you are sure that a DataFrame with "
f"{count} rows will fit in local memory, set ds.to_pandas(limit=None) "
"to disable limits."
)
if limit is not None:
count = self.count()
if count > limit:
raise ValueError(
f"the dataset has more than the given limit of {limit} "
f"rows: {count}. If you are sure that a DataFrame with "
f"{count} rows will fit in local memory, set "
"ds.to_pandas(limit=None) to disable limits."
)
blocks = self.get_internal_block_refs()
output = DelegatingBlockBuilder()
for block in blocks:
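The reordering in `to_pandas` means the extra `count()` pass only happens when a limit is actually supplied. A brief usage sketch under that assumption; the dataset size is arbitrary:

```python
import ray

ds = ray.data.range(1_000)

# limit=None (the default here): blocks are fetched and concatenated directly,
# with no separate row-count pass beforehand.
df = ds.to_pandas()

# With an explicit limit, count() runs first and raises ValueError if the
# dataset has more rows than the limit allows.
df_limited = ds.to_pandas(limit=10_000)
```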
@@ -4563,7 +4570,36 @@ def stats(self) -> str:
def _get_stats_summary(self) -> DatasetStatsSummary:
return self._plan.stats_summary()

@ConsumptionAPI(pattern="Time complexity:")
@ConsumptionAPI(pattern="Examples:")
@DeveloperAPI
def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
"""Get an iterator over ``RefBundles``
belonging to this Dataset. Calling this function doesn't keep
the data materialized in-memory.

Examples:
>>> import ray
>>> ds = ray.data.range(1)
>>> for ref_bundle in ds.iter_internal_ref_bundles():
... for block_ref, block_md in ref_bundle.blocks:
... block = ray.get(block_ref)

Returns:
An iterator over this Dataset's ``RefBundles``.
"""

def _build_ref_bundles(
iter_blocks: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
) -> Iterator[RefBundle]:
for block in iter_blocks:
yield RefBundle((block,), owns_blocks=True)

iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
iter_ref_bundles = _build_ref_bundles(iter_block_refs_md)
self._synchronize_progress_bar()
return iter_ref_bundles

@ConsumptionAPI(pattern="Examples:")
@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
Contributor: (can do this later) There are only a few use cases of get_internal_block_refs; we can also update them to use iter_internal_block_refs.

"""Get a list of references to the underlying blocks of this dataset.
Expand All @@ -4577,11 +4613,11 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
>>> ds.get_internal_block_refs()
[ObjectRef(...)]

Time complexity: O(1)
Contributor Author: Removing this because it's no longer accurate.


Returns:
A list of references to this dataset's blocks.
"""
# TODO(scottjlee): replace get_internal_block_refs() usages with
# iter_internal_ref_bundles()
block_refs = self._plan.execute().block_refs
self._synchronize_progress_bar()
return block_refs
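Following the reviewer's note above about migrating the remaining `get_internal_block_refs` call sites, a hedged sketch of what such a migration could look like; `process_block` is a placeholder for whatever the caller does with each block:

```python
import ray


def process_block(block) -> None:
    """Placeholder for caller-specific per-block work."""
    ...


ds = ray.data.range(100)

# Before: the full list of block references is held for the whole loop.
for block_ref in ds.get_internal_block_refs():
    process_block(ray.get(block_ref))

# After: bundles are yielded one at a time and can be released as we go.
for ref_bundle in ds.iter_internal_ref_bundles():
    for block_ref, _block_md in ref_bundle.blocks:
        process_block(ray.get(block_ref))
```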
13 changes: 13 additions & 0 deletions python/ray/data/tests/test_consumption.py
@@ -158,6 +158,19 @@ def test_count_edge_case(ray_start_regular):
assert actual_count == 5


def test_count_after_partial_execution(ray_start_regular):
paths = ["example://iris.csv"] * 5
ds = ray.data.read_csv(paths, override_num_blocks=15)
for batch in ds.iter_batches(batch_size=1):
# Take one batch and break to simulate partial iteration/execution.
break
# Row count should be unknown after partial execution.
assert "num_rows=?" in str(ds)
# After calling `ds.count()`, row count should be known.
assert ds.count() == 150 * 5
assert f"num_rows={150*5}" in str(ds)


def test_limit_execution(ray_start_regular):
last_snapshot = get_initial_core_execution_metrics_snapshot()
override_num_blocks = 20
17 changes: 17 additions & 0 deletions python/ray/data/tests/test_formats.py
@@ -85,6 +85,23 @@ def test_get_internal_block_refs(ray_start_regular_shared):
assert out == list(range(10)), out


def test_iter_internal_ref_bundles(ray_start_regular_shared):
n = 10
ds = ray.data.range(n, override_num_blocks=n)
iter_ref_bundles = ds.iter_internal_ref_bundles()

out = []
ref_bundle_count = 0
for ref_bundle in iter_ref_bundles:
for block_ref, block_md in ref_bundle.blocks:
b = ray.get(block_ref)
out.extend(extract_values("id", BlockAccessor.for_block(b).iter_rows(True)))
ref_bundle_count += 1
out = sorted(out)
assert ref_bundle_count == n
assert out == list(range(n)), out


def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
"""Same as `test_parquet_write` but using a custom, fsspec filesystem.
