[Data] Update `Dataset.count()` to avoid unnecessarily keeping `BlockRef`s in-memory #46369
@@ -12,6 +12,7 @@
     Dict,
     Generic,
     Iterable,
+    Iterator,
     List,
     Literal,
     Mapping,
@@ -77,6 +78,7 @@
     VALID_BATCH_FORMATS,
     Block,
     BlockAccessor,
+    BlockMetadata,
     DataBatch,
     T,
     U,
@@ -2459,12 +2461,14 @@ def show(self, limit: int = 20) -> None:
     @ConsumptionAPI(
         if_more_than_read=True,
         datasource_metadata="row count",
-        pattern="Time complexity:",
+        pattern="Examples:",
     )
     def count(self) -> int:
-        """Count the number of records in the dataset.
+        """Count the number of rows in the dataset.

-        Time complexity: O(dataset size / parallelism), O(1) for parquet
+        For Datasets which only read Parquet files (created with
+        :meth:`~ray.data.read_parquet`), this method reads the file metadata to
+        efficiently count the number of rows without reading in the entire data.

         Examples:
             >>> import ray
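For context, here is a minimal sketch of the Parquet fast path described in the updated docstring; the dataset path below is an assumed example, not part of this diff.

```python
import ray

# Assumed example path; any pure-Parquet read should behave the same way.
ds = ray.data.read_parquet("example://iris.parquet")

# Because this Dataset only reads Parquet files, count() can answer from
# file metadata without loading the row data itself.
print(ds.count())
```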
@@ -2484,13 +2488,15 @@ def count(self) -> int:
         if meta_count is not None:
             return meta_count

-        get_num_rows = cached_remote_fn(_get_num_rows)
-
-        return sum(
-            ray.get(
-                [get_num_rows.remote(block) for block in self.get_internal_block_refs()]
-            )
-        )
+        # Directly loop over the iterator of `RefBundle`s instead of
+        # retrieving a full list of `BlockRef`s.
+        total_rows = 0
+        for ref_bundle in self.iter_internal_ref_bundles():
+            num_rows = ref_bundle.num_rows()
+            # Executing the dataset always returns blocks with valid `num_rows`.
+            assert num_rows is not None
+            total_rows += num_rows
+        return total_rows

     @ConsumptionAPI(
         if_more_than_read=True,
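A minimal sketch of the same streaming pattern from user code, assuming the `iter_internal_ref_bundles` API added in this PR; the dataset size is illustrative.

```python
import ray

ds = ray.data.range(1000)

# Sum row counts bundle by bundle instead of first materializing a full
# list of block references with get_internal_block_refs().
total_rows = 0
for ref_bundle in ds.iter_internal_ref_bundles():
    num_rows = ref_bundle.num_rows()
    # Executed blocks are expected to carry num_rows; skip defensively if not.
    if num_rows is not None:
        total_rows += num_rows

assert total_rows == 1000
```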
@@ -4328,14 +4334,15 @@ def to_pandas(self, limit: int = None) -> "pandas.DataFrame":
             ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds
                 ``limit``.
         """
-        count = self.count()
-        if limit is not None and count > limit:
-            raise ValueError(
-                f"the dataset has more than the given limit of {limit} "
-                f"rows: {count}. If you are sure that a DataFrame with "
-                f"{count} rows will fit in local memory, set ds.to_pandas(limit=None) "
-                "to disable limits."
-            )
+        if limit is not None:
+            count = self.count()
+            if count > limit:
+                raise ValueError(
+                    f"the dataset has more than the given limit of {limit} "
+                    f"rows: {count}. If you are sure that a DataFrame with "
+                    f"{count} rows will fit in local memory, set "
+                    "ds.to_pandas(limit=None) to disable limits."
+                )
         blocks = self.get_internal_block_refs()
         output = DelegatingBlockBuilder()
         for block in blocks:
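A short usage sketch of the reordered check: with this change, `to_pandas()` only pays for a `count()` when a `limit` is actually supplied (dataset and values are illustrative).

```python
import ray

ds = ray.data.range(10)

# No limit: converts directly, without an up-front count() pass.
df = ds.to_pandas()

# With a limit: the row count is still validated before conversion.
df = ds.to_pandas(limit=100)
```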
@@ -4563,7 +4570,35 @@ def stats(self) -> str:
     def _get_stats_summary(self) -> DatasetStatsSummary:
         return self._plan.stats_summary()

-    @ConsumptionAPI(pattern="Time complexity:")
+    @ConsumptionAPI(pattern="Examples:")
     @DeveloperAPI
+    def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
+        """Get an iterator over ``RefBundles``
+        belonging to this Dataset. Calling this function doesn't keep
+        the data materialized in-memory.
+
+        Examples:
+            >>> import ray
+            >>> ds = ray.data.range(1)
+            >>> for ref_bundle in ds.iter_internal_ref_bundles():
+            ...     for block_ref, block_md in ref_bundle.blocks:
+            ...         block = ray.get(block_ref)
+
+        Returns:
+            An iterator over this Dataset's ``RefBundles``.
+        """
+
+        def _build_ref_bundle(
+            blocks: Tuple[ObjectRef[Block], BlockMetadata],
+        ) -> RefBundle:
+            return RefBundle((blocks,), owns_blocks=True)
+
+        iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
+        iter_ref_bundles = map(_build_ref_bundle, iter_block_refs_md)
+        self._synchronize_progress_bar()
+        return iter_ref_bundles
+
+    @ConsumptionAPI(pattern="Examples:")
+    @DeveloperAPI
     def get_internal_block_refs(self) -> List[ObjectRef[Block]]:

Review comment: (can do this later) there are only a few use cases of …

         """Get a list of references to the underlying blocks of this dataset.
@@ -4577,11 +4612,11 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
             >>> ds.get_internal_block_refs()
             [ObjectRef(...)]

-        Time complexity: O(1)
-

Review comment: removing this because it's no longer accurate.

         Returns:
             A list of references to this dataset's blocks.
         """
+        # TODO(scottjlee): replace get_internal_block_refs() usages with
+        # iter_internal_ref_bundles()
         block_refs = self._plan.execute().block_refs
         self._synchronize_progress_bar()
         return block_refs
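To contrast the two developer APIs, a hedged sketch (illustrative only; both are `DeveloperAPI` and subject to change):

```python
import ray

ds = ray.data.range(100)

# Eager: returns the full list of block references at once.
block_refs = ds.get_internal_block_refs()
print(len(block_refs))

# Streaming: yields one RefBundle at a time, so all references never need
# to be held together in a single list.
for ref_bundle in ds.iter_internal_ref_bundles():
    for block_ref, block_md in ref_bundle.blocks:
        block = ray.get(block_ref)  # only fetch block data when needed
```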
@@ -82,34 +82,50 @@ def test_zip_pandas(ray_start_regular_shared):
     ds2 = ray.data.from_pandas(pd.DataFrame({"col3": ["a", "b"], "col4": ["d", "e"]}))
     ds = ds1.zip(ds2)
     assert ds.count() == 2
-    assert "{col1: int64, col2: int64, col3: object, col4: object}" in str(ds)

     result = list(ds.take())
     assert result[0] == {"col1": 1, "col2": 4, "col3": "a", "col4": "d"}

+    # Execute the dataset to get full schema.
+    ds = ds.materialize()
+    assert "{col1: int64, col2: int64, col3: object, col4: object}" in str(ds)
+

Review comment: Need to update tests for …

     ds3 = ray.data.from_pandas(pd.DataFrame({"col2": ["a", "b"], "col4": ["d", "e"]}))
     ds = ds1.zip(ds3)
     assert ds.count() == 2
-    assert "{col1: int64, col2: int64, col2_1: object, col4: object}" in str(ds)

     result = list(ds.take())
     assert result[0] == {"col1": 1, "col2": 4, "col2_1": "a", "col4": "d"}

+    # Execute the dataset to get full schema.
+    ds = ds.materialize()
+    assert "{col1: int64, col2: int64, col2_1: object, col4: object}" in str(ds)
+

 def test_zip_arrow(ray_start_regular_shared):
     ds1 = ray.data.range(5).map(lambda r: {"id": r["id"]})
     ds2 = ray.data.range(5).map(lambda r: {"a": r["id"] + 1, "b": r["id"] + 2})
     ds = ds1.zip(ds2)
     assert ds.count() == 5
-    assert "{id: int64, a: int64, b: int64}" in str(ds)

     result = list(ds.take())
     assert result[0] == {"id": 0, "a": 1, "b": 2}

+    # Execute the dataset to get full schema.
+    ds = ds.materialize()
+    assert "{id: int64, a: int64, b: int64}" in str(ds)
+
     # Test duplicate column names.
     ds = ds1.zip(ds1).zip(ds1)
     assert ds.count() == 5
-    assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)

     result = list(ds.take())
     assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}

+    # Execute the dataset to get full schema.
+    ds = ds.materialize()
+    assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)
+

 def test_zip_multiple_block_types(ray_start_regular_shared):
     df = pd.DataFrame({"spam": [0]})
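The test updates all follow the same pattern: assert on the full schema only after an explicit `materialize()`. A condensed sketch of that pattern (column values are illustrative):

```python
import ray
import pandas as pd

ds1 = ray.data.from_pandas(pd.DataFrame({"col1": [1, 2], "col2": [4, 5]}))
ds2 = ray.data.from_pandas(pd.DataFrame({"col3": ["a", "b"], "col4": ["d", "e"]}))
ds = ds1.zip(ds2)

assert ds.count() == 2

# Execute the dataset first so str(ds) reports the full zipped schema.
ds = ds.materialize()
assert "col3: object" in str(ds)
```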
Review comment: not an issue with this PR, but we should make `BlockMetadata.num_rows` non-nullable, to avoid repeating this check.

Reply: Yeah, let's definitely do this when we separate `BlockMetadata` from read tasks. Currently, `BlockMetadata.num_rows` must be nullable because some datasources don't know how many rows are yielded by each read task.
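For readers following the nullability discussion, a small sketch of why the check exists; the exact `BlockMetadata` field set is an assumption and may differ across Ray versions.

```python
from ray.data.block import BlockMetadata

# A datasource that can't predict its output size up front produces
# metadata whose num_rows is None until the block is actually generated.
# (Field set shown here is an assumption for illustration.)
meta = BlockMetadata(
    num_rows=None,
    size_bytes=None,
    schema=None,
    input_files=None,
    exec_stats=None,
)

# This is why count() asserts (and other callers check) num_rows is not None.
assert meta.num_rows is None
```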