Skip to content

Commit

Permalink
[core] Always make request on fetch_local in wait even if num_objects…
Browse files Browse the repository at this point in the history
… in memory (ray-project#50121)

We always want to make pull requests to get the objects passed into
`ray.wait` on different nodes when `ray.wait` is called with
`fetch_local`. Right now we don't make that request if we've already
gotten `num_objects` from the core worker memory store.

Closes ray-project#49257

---------

Signed-off-by: dayshah <[email protected]>
  • Loading branch information
dayshah authored Feb 16, 2025
1 parent 34a2d2c commit 52d5eb9
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 66 deletions.
5 changes: 3 additions & 2 deletions bazel/ray.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public")
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
load("@bazel_common//tools/maven:pom_file.bzl", "pom_file")
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")

COPTS_WITHOUT_LOG = select({
Expand All @@ -14,6 +14,7 @@ COPTS_WITHOUT_LOG = select({
"//conditions:default": [
"-Wunused-result",
"-Wconversion-null",
"-Wmisleading-indentation",
],
}) + select({
"//:clang-cl": [
Expand Down
2 changes: 2 additions & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ py_test_module_list(
"test_unhandled_error.py",
"test_widgets.py",
"accelerators/test_accelerators.py",
"test_wait.py",
],
size = "small",
tags = ["exclusive", "small_size_python_tests", "team:core"],
Expand Down Expand Up @@ -724,6 +725,7 @@ py_test_module_list(
"test_basic_3.py",
"test_basic_4.py",
"test_basic_5.py",
"test_wait.py",
"test_multiprocessing.py",
"test_list_actors.py",
"test_list_actors_2.py",
Expand Down
16 changes: 0 additions & 16 deletions python/ray/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,22 +657,6 @@ def test_put_get(shutdown_only):
assert value_before == value_after


def test_wait_timing(shutdown_only):
ray.init(num_cpus=2)

@ray.remote
def f():
time.sleep(1)

future = f.remote()

start = time.time()
ready, not_ready = ray.wait([future], timeout=0.2)
assert 0.2 < time.time() - start < 0.3
assert len(ready) == 0
assert len(not_ready) == 1


@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
def test_function_descriptor():
python_descriptor = ray._raylet.PythonFunctionDescriptor(
Expand Down
46 changes: 0 additions & 46 deletions python/ray/tests/test_basic_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,52 +592,6 @@ def call(actor):
assert ray.get(actor.get_num_threads.remote()) <= CONCURRENCY


def test_wait(ray_start_regular_shared):
@ray.remote
def f(delay):
time.sleep(delay)
return

object_refs = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)]
ready_ids, remaining_ids = ray.wait(object_refs)
assert len(ready_ids) == 1
assert len(remaining_ids) == 3
ready_ids, remaining_ids = ray.wait(object_refs, num_returns=4)
assert set(ready_ids) == set(object_refs)
assert remaining_ids == []

object_refs = [f.remote(0), f.remote(5)]
ready_ids, remaining_ids = ray.wait(object_refs, timeout=0.5, num_returns=2)
assert len(ready_ids) == 1
assert len(remaining_ids) == 1

# Verify that calling wait with duplicate object refs throws an
# exception.
x = ray.put(1)
with pytest.raises(Exception):
ray.wait([x, x])

# Make sure it is possible to call wait with an empty list.
ready_ids, remaining_ids = ray.wait([])
assert ready_ids == []
assert remaining_ids == []

# Test semantics of num_returns with no timeout.
obj_refs = [ray.put(i) for i in range(10)]
(found, rest) = ray.wait(obj_refs, num_returns=2)
assert len(found) == 2
assert len(rest) == 8

# Verify that incorrect usage raises a TypeError.
x = ray.put(1)
with pytest.raises(TypeError):
ray.wait(x)
with pytest.raises(TypeError):
ray.wait(1)
with pytest.raises(TypeError):
ray.wait([1])


def test_duplicate_args(ray_start_regular_shared):
@ray.remote
def f(arg1, arg2, arg1_duplicate, kwarg1=None, kwarg2=None, kwarg1_duplicate=None):
Expand Down
141 changes: 141 additions & 0 deletions python/ray/tests/test_wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# coding: utf-8

import pytest
import numpy as np
import time
import logging
import sys
import os

from ray._private.test_utils import client_test_enabled


if client_test_enabled():
from ray.util.client import ray
else:
import ray
import ray.util.state

logger = logging.getLogger(__name__)


def test_wait(ray_start_regular):
@ray.remote
def f(delay):
time.sleep(delay)
return

object_refs = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)]
ready_ids, remaining_ids = ray.wait(object_refs)
assert len(ready_ids) == 1
assert len(remaining_ids) == 3
ready_ids, remaining_ids = ray.wait(object_refs, num_returns=4)
assert set(ready_ids) == set(object_refs)
assert remaining_ids == []

object_refs = [f.remote(0), f.remote(5)]
ready_ids, remaining_ids = ray.wait(object_refs, timeout=0.5, num_returns=2)
assert len(ready_ids) == 1
assert len(remaining_ids) == 1

# Verify that calling wait with duplicate object refs throws an
# exception.
x = ray.put(1)
with pytest.raises(Exception):
ray.wait([x, x])

# Make sure it is possible to call wait with an empty list.
ready_ids, remaining_ids = ray.wait([])
assert ready_ids == []
assert remaining_ids == []

# Test semantics of num_returns with no timeout.
obj_refs = [ray.put(i) for i in range(10)]
(found, rest) = ray.wait(obj_refs, num_returns=2)
assert len(found) == 2
assert len(rest) == 8

# Verify that incorrect usage raises a TypeError.
x = ray.put(1)
with pytest.raises(TypeError):
ray.wait(x)
with pytest.raises(TypeError):
ray.wait(1)
with pytest.raises(TypeError):
ray.wait([1])


def test_wait_timing(ray_start_2_cpus):
@ray.remote
def f():
time.sleep(1)

future = f.remote()

start = time.time()
ready, not_ready = ray.wait([future], timeout=0.2)
assert 0.2 < time.time() - start < 0.3
assert len(ready) == 0
assert len(not_ready) == 1


@pytest.mark.skipif(client_test_enabled(), reason="util not available with ray client")
def test_wait_always_fetch_local(monkeypatch, ray_start_cluster):
monkeypatch.setenv("RAY_scheduler_report_pinned_bytes_only", "false")
cluster = ray_start_cluster
head_node = cluster.add_node(num_cpus=0, object_store_memory=300e6)
ray.init(address=cluster.address)
worker_node = cluster.add_node(num_cpus=1, object_store_memory=300e6)

@ray.remote(num_cpus=1)
def return_large_object():
# 100mb so will spill on worker, but not once on head
return np.zeros(100 * 1024 * 1024, dtype=np.uint8)

@ray.remote(num_cpus=0)
def small_local_task():
return 1

put_on_head = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
head_node.node_id, soft=False
)
put_on_worker = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
worker_node.node_id, soft=False
)
x = small_local_task.options(scheduling_strategy=put_on_head).remote()
y = return_large_object.options(scheduling_strategy=put_on_worker).remote()
z = return_large_object.options(scheduling_strategy=put_on_worker).remote()

# will return when tasks are done
ray.wait([x, y, z], num_returns=3, fetch_local=False)
assert (
ray._private.state.available_resources_per_node()[head_node.node_id][
"object_store_memory"
]
> 250e6
)

# x should be immediately available locally, start fetching y and z
ray.wait([x, y, z], num_returns=1, fetch_local=True)
assert (
ray._private.state.available_resources_per_node()[head_node.node_id][
"object_store_memory"
]
> 250e6
)

time.sleep(5)
# y, z should be pulled here
assert (
ray._private.state.available_resources_per_node()[head_node.node_id][
"object_store_memory"
]
< 150e6
)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
5 changes: 4 additions & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,10 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids,
if (fetch_local) {
MoveReadyPlasmaObjectsToPlasmaSet(
memory_store_, memory_object_ids, plasma_object_ids, ready);
if (static_cast<int>(ready.size()) < num_objects && !plasma_object_ids.empty()) {
// We make the request to the plasma store even if we have num_objects ready since we
// want to at least make the request to pull these objects if the user specified
// fetch_local so the pulling can start.
if (!plasma_object_ids.empty()) {
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
plasma_object_ids,
std::min(static_cast<int>(plasma_object_ids.size()),
Expand Down
25 changes: 25 additions & 0 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,31 @@ void NodeManager::ProcessWaitRequestMessage(
current_task_id,
/*ray_get=*/false);
}
if (message->num_ready_objects() == 0) {
// If we don't need to wait for any, return immediately after making the pull
// requests through AsyncResolveObjects above.
flatbuffers::FlatBufferBuilder fbb;
auto wait_reply = protocol::CreateWaitReply(fbb,
to_flatbuf(fbb, std::vector<ObjectID>{}),
to_flatbuf(fbb, std::vector<ObjectID>{}));
fbb.Finish(wait_reply);
const auto status =
client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
fbb.GetSize(),
fbb.GetBufferPointer());
if (status.ok()) {
if (resolve_objects) {
AsyncResolveObjectsFinish(client, current_task_id);
}
} else {
// We failed to write to the client, so disconnect the client.
std::ostringstream stream;
stream << "Failed to write WaitReply to the client. Status " << status
<< ", message: " << status.message();
DisconnectClient(client, rpc::WorkerExitType::SYSTEM_ERROR, stream.str());
}
return;
}
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
wait_manager_.Wait(
object_ids,
Expand Down
1 change: 0 additions & 1 deletion src/ray/raylet/wait_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void WaitManager::Wait(const std::vector<ObjectID> &object_ids,
<< "Waiting duplicate objects is not allowed. Please make sure all object IDs are "
"unique before calling `WaitManager::Wait`.";
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
RAY_CHECK_NE(num_required_objects, 0u);
RAY_CHECK_LE(num_required_objects, object_ids.size());

const uint64_t wait_id = next_wait_id_++;
Expand Down

0 comments on commit 52d5eb9

Please sign in to comment.