[Data] Update Dataset.count() to avoid unnecessarily keeping BlockRefs in-memory #46369

Merged: 21 commits, Jul 10, 2024
1 change: 1 addition & 0 deletions doc/source/data/api/dataset.rst
@@ -131,6 +131,7 @@ Inspecting Metadata
Dataset.input_files
Dataset.stats
Dataset.get_internal_block_refs
Dataset.iter_internal_ref_bundles

Execution
---------
2 changes: 2 additions & 0 deletions python/ray/data/_internal/plan.py
@@ -414,6 +414,8 @@ def execute_to_iterator(

metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag)
# TODO(scottjlee): replace with `execute_to_legacy_bundle_iterator` and
# update execute_to_iterator usages to handle RefBundles instead of Blocks
block_iter = execute_to_legacy_block_iterator(
executor,
self,
77 changes: 56 additions & 21 deletions python/ray/data/dataset.py
@@ -12,6 +12,7 @@
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
Mapping,
@@ -77,6 +78,7 @@
VALID_BATCH_FORMATS,
Block,
BlockAccessor,
BlockMetadata,
DataBatch,
T,
U,
@@ -2459,12 +2461,14 @@ def show(self, limit: int = 20) -> None:
@ConsumptionAPI(
if_more_than_read=True,
datasource_metadata="row count",
pattern="Time complexity:",
pattern="Examples:",
)
def count(self) -> int:
"""Count the number of records in the dataset.
"""Count the number of rows in the dataset.

Time complexity: O(dataset size / parallelism), O(1) for parquet
For Datasets which only read Parquet files (created with
:meth:`~ray.data.read_parquet`), this method reads the file metadata to
efficiently count the number of rows without reading in all the data.

Examples:
>>> import ray
@@ -2484,13 +2488,15 @@ def count(self) -> int:
if meta_count is not None:
return meta_count

get_num_rows = cached_remote_fn(_get_num_rows)

return sum(
ray.get(
[get_num_rows.remote(block) for block in self.get_internal_block_refs()]
)
)
# Directly loop over the iterator of `RefBundle`s instead of
# retrieving a full list of `BlockRef`s.
total_rows = 0
for ref_bundle in self.iter_internal_ref_bundles():
num_rows = ref_bundle.num_rows()
# Executing the dataset always returns blocks with valid `num_rows`.
assert num_rows is not None
total_rows += num_rows
Contributor:

Not an issue with this PR, but we should make BlockMetadata.num_rows non-nullable, to avoid repeating this check.

Member:

Yeah, let's definitely do this when we separate BlockMetadata from read tasks. Currently, BlockMetadata.num_rows must be nullable because some datasources don't know how many rows are yielded by each read task.
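
For reference, a minimal sketch (not part of this PR) of the check being discussed, using only the public RefBundle.num_rows() accessor; the helper name is illustrative:

def _sum_bundle_rows(ref_bundles):
    # RefBundle.num_rows() returns None if any block's metadata is missing a
    # row count. Executed datasets always populate it, which is why the new
    # count() can assert instead of handling the None case.
    total = 0
    for bundle in ref_bundles:
        num_rows = bundle.num_rows()
        if num_rows is None:
            raise ValueError("row count unknown for at least one block")
        total += num_rows
    return total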

return total_rows

@ConsumptionAPI(
if_more_than_read=True,
@@ -4328,14 +4334,15 @@ def to_pandas(self, limit: int = None) -> "pandas.DataFrame":
ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds
``limit``.
"""
count = self.count()
if limit is not None and count > limit:
raise ValueError(
f"the dataset has more than the given limit of {limit} "
f"rows: {count}. If you are sure that a DataFrame with "
f"{count} rows will fit in local memory, set ds.to_pandas(limit=None) "
"to disable limits."
)
if limit is not None:
count = self.count()
if count > limit:
raise ValueError(
f"the dataset has more than the given limit of {limit} "
f"rows: {count}. If you are sure that a DataFrame with "
f"{count} rows will fit in local memory, set "
"ds.to_pandas(limit=None) to disable limits."
)
blocks = self.get_internal_block_refs()
output = DelegatingBlockBuilder()
for block in blocks:
@@ -4563,7 +4570,35 @@ def stats(self) -> str:
def _get_stats_summary(self) -> DatasetStatsSummary:
return self._plan.stats_summary()

@ConsumptionAPI(pattern="Time complexity:")
@ConsumptionAPI(pattern="Examples:")
@DeveloperAPI
def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
"""Get an iterator over ``RefBundles``
belonging to this Dataset. Calling this function doesn't keep
the data materialized in-memory.

Examples:
>>> import ray
>>> ds = ray.data.range(1)
>>> for ref_bundle in ds.iter_internal_ref_bundles():
... for block_ref, block_md in ref_bundle.blocks:
... block = ray.get(block_ref)

Returns:
An iterator over this Dataset's ``RefBundles``.
"""

def _build_ref_bundle(
blocks: Tuple[ObjectRef[Block], BlockMetadata],
) -> RefBundle:
return RefBundle((blocks,), owns_blocks=True)

iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
iter_ref_bundles = map(_build_ref_bundle, iter_block_refs_md)
self._synchronize_progress_bar()
return iter_ref_bundles

@ConsumptionAPI(pattern="Examples:")
@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
Contributor:

(Can do this later.) There are only a few use cases of get_internal_block_refs; we can also update them to use iter_internal_ref_bundles.
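
A rough sketch of what such a migration could look like for a caller that only needs the block refs, assuming ds is an existing Dataset (illustrative only, not part of this PR):

# Before: pulls every block ref into a single in-memory list.
block_refs = ds.get_internal_block_refs()

# After: stream RefBundles and unpack the (block_ref, metadata) pairs lazily.
block_refs = [
    block_ref
    for ref_bundle in ds.iter_internal_ref_bundles()
    for block_ref, _ in ref_bundle.blocks
]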

"""Get a list of references to the underlying blocks of this dataset.
@@ -4577,11 +4612,11 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
>>> ds.get_internal_block_refs()
[ObjectRef(...)]

Time complexity: O(1)
Contributor Author:

Removing this because it's no longer accurate.


Returns:
A list of references to this dataset's blocks.
"""
# TODO(scottjlee): replace get_internal_block_refs() usages with
# iter_internal_ref_bundles()
block_refs = self._plan.execute().block_refs
self._synchronize_progress_bar()
return block_refs
17 changes: 17 additions & 0 deletions python/ray/data/tests/test_formats.py
@@ -85,6 +85,23 @@ def test_get_internal_block_refs(ray_start_regular_shared):
assert out == list(range(10)), out


def test_iter_internal_ref_bundles(ray_start_regular_shared):
n = 10
ds = ray.data.range(n, override_num_blocks=n)
iter_ref_bundles = ds.iter_internal_ref_bundles()

out = []
ref_bundle_count = 0
for ref_bundle in iter_ref_bundles:
for block_ref, block_md in ref_bundle.blocks:
b = ray.get(block_ref)
out.extend(extract_values("id", BlockAccessor.for_block(b).iter_rows(True)))
ref_bundle_count += 1
out = sorted(out)
assert ref_bundle_count == n
assert out == list(range(n)), out


def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
"""Same as `test_parquet_write` but using a custom, fsspec filesystem.

24 changes: 20 additions & 4 deletions python/ray/data/tests/test_zip.py
@@ -82,34 +82,50 @@ def test_zip_pandas(ray_start_regular_shared):
ds2 = ray.data.from_pandas(pd.DataFrame({"col3": ["a", "b"], "col4": ["d", "e"]}))
ds = ds1.zip(ds2)
assert ds.count() == 2
assert "{col1: int64, col2: int64, col3: object, col4: object}" in str(ds)

result = list(ds.take())
assert result[0] == {"col1": 1, "col2": 4, "col3": "a", "col4": "d"}

# Execute the dataset to get full schema.
ds = ds.materialize()
assert "{col1: int64, col2: int64, col3: object, col4: object}" in str(ds)

Contributor Author:

Need to update the tests for Zip because they call ds.count() before checking the schema via the Dataset.__str__ representation. After updating ds.count() to no longer execute and fetch the list of underlying Blocks, the schema is unknown for N-ary operators without executing: https://github.com/ray-project/ray/blob/master/python/ray/data/_internal/plan.py#L155-L158
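
In short, the updated tests follow this pattern (a sketch mirroring the diff below; ds1 and ds2 are the test DataFrames):

ds = ds1.zip(ds2)
assert ds.count() == 2   # count() no longer executes the full plan
ds = ds.materialize()    # N-ary operators only report a schema after execution
assert "col1: int64" in str(ds)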

ds3 = ray.data.from_pandas(pd.DataFrame({"col2": ["a", "b"], "col4": ["d", "e"]}))
ds = ds1.zip(ds3)
assert ds.count() == 2
assert "{col1: int64, col2: int64, col2_1: object, col4: object}" in str(ds)

result = list(ds.take())
assert result[0] == {"col1": 1, "col2": 4, "col2_1": "a", "col4": "d"}

# Execute the dataset to get full schema.
ds = ds.materialize()
assert "{col1: int64, col2: int64, col2_1: object, col4: object}" in str(ds)


def test_zip_arrow(ray_start_regular_shared):
ds1 = ray.data.range(5).map(lambda r: {"id": r["id"]})
ds2 = ray.data.range(5).map(lambda r: {"a": r["id"] + 1, "b": r["id"] + 2})
ds = ds1.zip(ds2)
assert ds.count() == 5
assert "{id: int64, a: int64, b: int64}" in str(ds)

result = list(ds.take())
assert result[0] == {"id": 0, "a": 1, "b": 2}

# Execute the dataset to get full schema.
ds = ds.materialize()
assert "{id: int64, a: int64, b: int64}" in str(ds)

# Test duplicate column names.
ds = ds1.zip(ds1).zip(ds1)
assert ds.count() == 5
assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)

result = list(ds.take())
assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}

# Execute the dataset to get full schema.
ds = ds.materialize()
assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)


def test_zip_multiple_block_types(ray_start_regular_shared):
df = pd.DataFrame({"spam": [0]})