Commit

Merge branch 'master' into feat/add-job-filter

Betula-L authored Feb 7, 2025
2 parents 93d068a + a5d6042 commit 61d47fc

Showing 22 changed files with 260 additions and 103 deletions.
22 changes: 22 additions & 0 deletions doc/source/data/transforming-data.rst
@@ -312,6 +312,28 @@ To transform data with a Python class, complete these steps:

ds.materialize()

Avoiding out-of-memory errors
=============================

If your user-defined function uses a lot of memory, you might encounter out-of-memory
errors. To avoid these errors, configure the ``memory`` parameter. It tells Ray how much
heap memory your function needs, so Ray doesn't schedule too many tasks onto a single node.

.. testcode::
:hide:

from typing import Dict

import numpy as np
import ray

ds = ray.data.range(1)

.. testcode::

def uses_lots_of_memory(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
...

# Tell Ray that the function uses 1 GiB of memory
ds.map_batches(uses_lots_of_memory, memory=1 * 1024 * 1024 * 1024)
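
The same ``memory`` argument is available on the other map-style APIs, such as
:meth:`~ray.data.Dataset.map` and :meth:`~ray.data.Dataset.flat_map`. A minimal
sketch, assuming a hypothetical ``process_row`` function:

.. testcode::

    from typing import Any, Dict

    def process_row(row: Dict[str, Any]) -> Dict[str, Any]:
        # Placeholder for a memory-intensive per-row transformation.
        return row

    # Reserve 2 GiB of heap memory for each map worker.
    ds.map(process_row, memory=2 * 1024 * 1024 * 1024)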

.. _transforming_groupby:

Groupby and transforming groups
12 changes: 8 additions & 4 deletions doc/source/ray-core/compiled-graph/overlap.rst
@@ -52,11 +52,15 @@ from overlapping.
with InputNode() as inp:
branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
branches = [
branch.with_type_hint(
TorchTensorType(
transport="nccl", _static_shape=True, _direct_return=True
)
branch.with_tensor_transport(
transport="nccl", _static_shape=True, _direct_return=True
)
# For Ray versions before 2.42, use `with_type_hint()` instead.
# branch.with_type_hint(
# TorchTensorType(
# transport="nccl", _static_shape=True, _direct_return=True
# )
# )
for branch in branches
]
branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
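# A sketch of how a DAG built from these branches is typically compiled with
# communication overlap enabled (assumed names from later in this guide; verify
# `_overlap_gpu_communication` against your Ray version):
#
#   dag = MultiOutputNode(branches)
#   compiled_dag = dag.experimental_compile(_overlap_gpu_communication=True)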
12 changes: 9 additions & 3 deletions doc/source/ray-core/compiled-graph/quickstart.rst
@@ -7,7 +7,11 @@ This "hello world" example uses Ray Compiled Graph. First, install Ray.

.. testcode::

pip install "ray[adag]"
pip install "ray[cg]"

# For Ray versions before 2.41, use the following instead:
# pip install "ray[adag]"


We will define a simple actor.

@@ -122,15 +126,17 @@ Next, create sender and receiver actors.
sender = GPUSender.remote()
receiver = GPUReceiver.remote()

To support GPU to GPU RDMA with NCCL, you can use ``with_type_hint`` API with Compiled Graph.
To support GPU-to-GPU RDMA with NCCL, you can use the ``with_tensor_transport`` API with Compiled Graph.

.. testcode::

with ray.dag.InputNode() as inp:
dag = sender.send.bind(inp)
# It gives a type hint that the return value of `send` should use
# NCCL.
dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
dag = dag.with_tensor_transport("nccl")
# For Ray versions before 2.42, use `with_type_hint()` instead.
# dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
dag = receiver.recv.bind(dag)

# Compile API prepares the NCCL communicator across all workers and schedule operations
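# A sketch of the remaining steps (the input value below is illustrative):
#
#   compiled_dag = dag.experimental_compile()
#   tensor_ref = compiled_dag.execute(some_tensor)
#   result = ray.get(tensor_ref)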
4 changes: 3 additions & 1 deletion doc/source/serve/production-guide/fault-tolerance.md
@@ -17,7 +17,9 @@ This section discusses concepts from:
(serve-e2e-ft-guide)=
## Guide: end-to-end fault tolerance for your Serve app

Serve provides some [fault tolerance](serve-ft-detail) features out of the box. You can provide end-to-end fault tolerance by tuning these features and running Serve on top of [KubeRay].
Serve provides some [fault tolerance](serve-ft-detail) features out of the box. You can get end-to-end fault tolerance in two ways:
* tune these features and run Serve on top of [KubeRay]
* use the [Anyscale platform](https://docs.anyscale.com/platform/services/head-node-ft), a managed Ray platform

### Replica health-checking

1 change: 1 addition & 0 deletions python/ray/_private/test_utils.py
@@ -19,6 +19,7 @@
import traceback
from collections import defaultdict
from contextlib import contextmanager, redirect_stderr, redirect_stdout

from typing import Any, Callable, Dict, List, Optional, Tuple
import uuid
from dataclasses import dataclass
2 changes: 1 addition & 1 deletion python/ray/dag/BUILD
@@ -134,7 +134,7 @@ py_test_module_list(

py_test(
name = "test_torch_tensor_dag_gpu",
size = "large",
size = "enormous",
srcs = [
"tests/experimental/test_torch_tensor_dag.py",
],
61 changes: 33 additions & 28 deletions python/ray/dag/compiled_dag_node.py
@@ -2090,34 +2090,39 @@ def teardown(self, kill_actors: bool = False):
return

logger.info("Tearing down compiled DAG")
outer._dag_submitter.close()
outer._dag_output_fetcher.close()

for actor in outer.actor_refs:
logger.info(f"Cancelling compiled worker on actor: {actor}")
# Cancel all actor loops in parallel.
cancel_refs = [
actor.__ray_call__.remote(do_cancel_executable_tasks, tasks)
for actor, tasks in outer.actor_to_executable_tasks.items()
]
for cancel_ref in cancel_refs:
try:
ray.get(cancel_ref, timeout=30)
except RayChannelError:
# Channel error happens when a channel is closed
# or timed out. In this case, do not log.
pass
except Exception:
logger.exception("Error cancelling worker task")
pass

for (
communicator_id
) in outer._actors_to_created_communicator_id.values():
_destroy_communicator(communicator_id)

logger.info("Waiting for worker tasks to exit")
self.wait_teardown(kill_actors=kill_actors)
try:
outer._dag_submitter.close()
outer._dag_output_fetcher.close()

for actor in outer.actor_refs:
logger.info(f"Cancelling compiled worker on actor: {actor}")
# Cancel all actor loops in parallel.
cancel_refs = [
actor.__ray_call__.remote(do_cancel_executable_tasks, tasks)
for actor, tasks in outer.actor_to_executable_tasks.items()
]
for cancel_ref in cancel_refs:
try:
ray.get(cancel_ref, timeout=30)
except RayChannelError:
# Channel error happens when a channel is closed
# or timed out. In this case, do not log.
pass
except Exception:
logger.exception("Error cancelling worker task")
pass

for (
communicator_id
) in outer._actors_to_created_communicator_id.values():
_destroy_communicator(communicator_id)

logger.info("Waiting for worker tasks to exit")
self.wait_teardown(kill_actors=kill_actors)
except ReferenceError:
# Python destruction order is not guaranteed, so we may
# access attributes of `outer` which are already destroyed.
logger.info("Compiled DAG is already destroyed")
logger.info("Teardown complete")
self._teardown_done = True

6 changes: 2 additions & 4 deletions python/ray/data/BUILD
@@ -243,7 +243,7 @@

py_test(
name = "test_image",
size = "small",
size = "medium",
srcs = ["tests/test_image.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
@@ -353,8 +353,6 @@ py_test(
deps = ["//:ray_lib", ":conftest"],
)

# Added "data_non_parallel" tag to prevent parallel execution of this test.
# It needs about 5GB memory.
py_test(
name = "test_dynamic_block_split",
size = "medium",
@@ -505,7 +503,7 @@ py_test(
name = "test_sort",
size = "enormous",
srcs = ["tests/test_sort.py"],
tags = ["team:data", "exclusive"],
tags = ["team:data", "exclusive", "data_non_parallel"],
deps = ["//:ray_lib", ":conftest"],
)

2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_block.py
@@ -146,7 +146,7 @@ def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:

@staticmethod
def _concat_tables(tables: List[Block]) -> Block:
return transform_pyarrow.concat(tables)
return transform_pyarrow.concat(tables, promote_types=True)

@staticmethod
def _concat_would_copy() -> bool:
26 changes: 17 additions & 9 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -406,7 +406,9 @@ def _align_struct_fields(
return aligned_blocks


def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
def concat(
blocks: List["pyarrow.Table"], *, promote_types: bool = False
) -> "pyarrow.Table":
"""Concatenate provided Arrow Tables into a single Arrow Table. This has special
handling for extension types that pyarrow.concat_tables does not yet support.
"""
@@ -432,7 +434,7 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
# If the result contains pyarrow schemas, unify them
schemas_to_unify = [b.schema for b in blocks]
try:
schema = unify_schemas(schemas_to_unify)
schema = unify_schemas(schemas_to_unify, promote_types=promote_types)
except Exception as e:
raise ArrowConversionError(str(blocks)) from e

@@ -519,13 +521,19 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table":
else:
# No extension array columns, so use built-in pyarrow.concat_tables.

if parse_version(_get_pyarrow_version()) >= parse_version("14.0.0"):
# `promote` was superseded by `promote_options='default'` in Arrow 14. To
# prevent `FutureWarning`s, we manually check the Arrow version and use the
# appropriate parameter.
table = pyarrow.concat_tables(blocks, promote_options="default")
else:
# When concatenating tables, allow type promotions to occur, since no schema
# enforcement is currently performed and schemas can therefore vary between
# blocks.
#
# NOTE: Type promotions aren't available in Arrow < 14.0.
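# For example (a sketch, assuming blocks with a shared column name "x"):
#   t1 = pyarrow.table({"x": pyarrow.array([1], type=pyarrow.int32())})
#   t2 = pyarrow.table({"x": pyarrow.array([2], type=pyarrow.int64())})
#   concat([t1, t2], promote_types=True)  # "x" is promoted to int64
#   concat([t1, t2])                      # mismatched types fail to unify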
if parse_version(_get_pyarrow_version()) < parse_version("14.0.0"):
table = pyarrow.concat_tables(blocks, promote=True)
else:
arrow_promote_types_mode = "permissive" if promote_types else "default"
table = pyarrow.concat_tables(
blocks, promote_options=arrow_promote_types_mode
)

return table


@@ -534,7 +542,7 @@ def concat_and_sort(
) -> "pyarrow.Table":
import pyarrow.compute as pac

ret = concat(blocks)
ret = concat(blocks, promote_types=True)
indices = pac.sort_indices(ret, sort_keys=sort_key.to_arrow_sort_args())
return take_table(ret, indices)

2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/parquet_datasink.py
@@ -125,7 +125,7 @@ def _write_partition_files(
import pyarrow as pa
import pyarrow.parquet as pq

table = concat(tables)
table = concat(tables, promote_types=False)
# Create unique combinations of the partition columns
table_fields = [
field for field in output_schema if field.name not in self.partition_cols
17 changes: 17 additions & 0 deletions python/ray/data/dataset.py
@@ -265,6 +265,7 @@ def map(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
@@ -336,6 +337,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
concurrency: The number of Ray workers to use concurrently. For a fixed-sized
worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling
worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``.
@@ -371,6 +373,9 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
if num_gpus is not None:
ray_remote_args["num_gpus"] = num_gpus

if memory is not None:
ray_remote_args["memory"] = memory
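# For example, `ds.map(fn, memory=2 * 1024**3)` forwards 2 GiB as the `memory`
# resource requirement of each map worker, so the Ray scheduler limits how many
# workers it places on a single node.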

plan = self._plan.copy()
map_op = MapRows(
self._logical_plan.dag,
@@ -413,6 +418,7 @@ def map_batches(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
@@ -561,6 +567,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
num_cpus: The number of CPUs to reserve for each parallel map worker.
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
concurrency: The number of Ray workers to use concurrently. For a fixed-sized
worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling
worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``.
@@ -630,6 +637,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
fn_constructor_kwargs=fn_constructor_kwargs,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
ray_remote_args_fn=ray_remote_args_fn,
**ray_remote_args,
@@ -649,6 +657,7 @@ def _map_batches_without_batch_size_validation(
fn_constructor_kwargs: Optional[Dict[str, Any]],
num_cpus: Optional[float],
num_gpus: Optional[float],
memory: Optional[float],
concurrency: Optional[Union[int, Tuple[int, int]]],
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
**ray_remote_args,
@@ -672,6 +681,9 @@ def _map_batches_without_batch_size_validation(
if num_gpus is not None:
ray_remote_args["num_gpus"] = num_gpus

if memory is not None:
ray_remote_args["memory"] = memory

batch_format = _apply_batch_format(batch_format)

min_rows_per_bundled_input = None
@@ -1103,6 +1115,7 @@ def flat_map(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
@@ -1168,6 +1181,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
concurrency: The number of Ray workers to use concurrently. For a
fixed-sized worker pool of size ``n``, specify ``concurrency=n``.
For an autoscaling worker pool from ``m`` to ``n`` workers, specify
@@ -1202,6 +1216,9 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
if num_gpus is not None:
ray_remote_args["num_gpus"] = num_gpus

if memory is not None:
ray_remote_args["memory"] = memory

plan = self._plan.copy()
op = FlatMap(
input_op=self._logical_plan.dag,
3 changes: 3 additions & 0 deletions python/ray/data/grouped_data.py
@@ -105,6 +105,7 @@ def map_groups(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
**ray_remote_args,
) -> "Dataset":
@@ -175,6 +176,7 @@ def map_groups(
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
ray_remote_args: Additional resource requirements to request from
Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
:func:`ray.remote` for details.
@@ -257,6 +259,7 @@ def wrapped_fn(batch, *args, **kwargs):
fn_constructor_kwargs=fn_constructor_kwargs,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
ray_remote_args_fn=None,
**ray_remote_args,