Avoid stale CommsContext in explicit comms
This PR updates the CommsContext caching to be keyed by the client ID and the
cluster's worker addresses, rather than stored in a single global. This prevents
us from using a stale comms object after the cluster changes (workers added or
removed) or is recreated entirely.

Closes rapidsai#1450
TomAugspurger committed Feb 14, 2025
1 parent 4795210 commit ce3957a
Showing 3 changed files with 122 additions and 8 deletions.
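In outline, the caching scheme this commit introduces in dask_cuda/explicit_comms/comms.py works like the sketch below. This is illustrative only: cached_comms and make_comms are hypothetical names standing in for the real default_comms and CommsContext(client=client) shown in the diff that follows.

import weakref

from dask.tokenize import tokenize

# Weak values: a cached comms object is dropped once nothing else references it.
_cache: "weakref.WeakValueDictionary[str, object]" = weakref.WeakValueDictionary()


def cached_comms(client, make_comms):
    # Key on the client ID *and* the current worker addresses, so scaling the
    # cluster (or recreating it) produces a new key and a fresh comms object.
    workers = list(client.scheduler_info()["workers"].keys())
    token = tokenize(client.id, workers)
    comms = _cache.get(token)
    if comms is None:
        comms = make_comms(client)  # e.g. CommsContext(client=client)
        _cache[token] = comms
    return comms
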
49 changes: 42 additions & 7 deletions dask_cuda/explicit_comms/comms.py
@@ -1,15 +1,21 @@
# Copyright (c) 2021-2025 NVIDIA CORPORATION.
import asyncio
import concurrent.futures
import contextlib
import time
import uuid
import weakref
from typing import Any, Dict, Hashable, Iterable, List, Optional

import distributed.comm
from dask.tokenize import tokenize
from distributed import Client, Worker, default_client, get_worker
from distributed.comm.addressing import parse_address, parse_host_port, unparse_address

_default_comms = None
# Mapping tokenize(client ID, [worker addresses]) to CommsContext
_comms_cache: weakref.WeakValueDictionary[
    str, "CommsContext"
] = weakref.WeakValueDictionary()


def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
@@ -38,9 +44,10 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):


def default_comms(client: Optional[Client] = None) -> "CommsContext":
"""Return the default comms object
"""Return the default comms object for ``client``.
Creates a new default comms object if no one exist.
Creates a new default comms object if one does not already exist
for ``client``.
Parameters
----------
@@ -52,11 +59,39 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
-------
comms: CommsContext
The default comms object
Notes
-----
There are some subtle points around explicit-comms and the lifecycle
of a Dask Cluster.
A :class:`CommsContext` establishes explicit communication channels
between the workers *at the time it's created*. If workers are added
or removed, they will not be included in the communication channels
with the other workers.
If you need to refresh the explicit communications channels, then
create a new :class:`CommsContext` object or call ``default_comms``
again after workers have been added to or removed from the cluster.
"""
    global _default_comms
    if _default_comms is None:
        _default_comms = CommsContext(client=client)
    return _default_comms
    # Comms are unique to a client, so we need to know.
    client = client or default_client()
    # What behavior do we want with two clients referring to the same cluster?
    # a = Client(cluster); b = Client(cluster)
    # Their `.id`s are different, but don't they have the same comms?
    # From the "this does stuff to the workers" point of view, yes they're
    # the same. But `CommsClient` has a reference to `.client`, so not quite.
    #
    # Recommendation: either ignore this (multiple clients are common) or
    # deprecate `CommsClient.client` and make users pass in
    # `client` when they call stuff (default of default_client()).
    token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
    maybe_comms = _comms_cache.get(token)
    if maybe_comms is None:
        maybe_comms = CommsContext(client=client)
        _comms_cache[token] = maybe_comms

    return maybe_comms


def worker_state(sessionId: Optional[int] = None) -> dict:
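A note on the cache's lifetime semantics, which test_create_destroy_create below relies on: because _comms_cache is a weakref.WeakValueDictionary, an entry disappears as soon as the last strong reference to its CommsContext goes away. A small standalone illustration, independent of dask-cuda (immediate eviction assumes CPython's reference counting):

import weakref


class Comms:
    # Stand-in for CommsContext; any weak-referenceable object behaves the same.
    pass


cache: "weakref.WeakValueDictionary[str, Comms]" = weakref.WeakValueDictionary()

obj = Comms()
cache["token"] = obj
assert len(cache) == 1  # entry is alive while `obj` holds a strong reference

del obj                 # drop the last strong reference
assert len(cache) == 0  # entry evicted (immediately on CPython)
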
79 changes: 79 additions & 0 deletions dask_cuda/tests/test_explicit_comms.py
@@ -1,3 +1,6 @@
#!/bin/bash
# Copyright (c) 2021-2025 NVIDIA CORPORATION.

import asyncio
import multiprocessing as mp
import os
@@ -415,3 +418,79 @@ def test_lock_workers():
p.join()

assert all(p.exitcode == 0 for p in ps)


def test_create_destroy_create():
    # https://github.com/rapidsai/dask-cuda/issues/1450
    assert len(comms._comms_cache) == 0
    with LocalCluster(n_workers=1) as cluster:
        with Client(cluster) as client:
            context = comms.default_comms()
            scheduler_addresses_old = list(client.scheduler_info()["workers"].keys())
            comms_addresses_old = list(comms.default_comms().worker_addresses)
            assert comms.default_comms() is context
            assert len(comms._comms_cache) == 1

            # Add a worker, which should have a new comms object
            cluster.scale(2)
            client.wait_for_workers(2)
            context2 = comms.default_comms()
            assert context is not context2
            assert len(comms._comms_cache) == 2

            del context
            del context2
            assert len(comms._comms_cache) == 0
            assert scheduler_addresses_old == comms_addresses_old

    # A new cluster should have a new comms object. Previously, this failed
    # because we referenced the old cluster's addresses.
    with LocalCluster(n_workers=1) as cluster:
        with Client(cluster) as client:
            scheduler_addresses_new = list(client.scheduler_info()["workers"].keys())
            comms_addresses_new = list(comms.default_comms().worker_addresses)

    assert scheduler_addresses_new == comms_addresses_new


def test_update():
    cluster = LocalCluster(n_workers=2)
    client = cluster.get_client()
    context_1 = comms.default_comms()

    def check(dask_worker, session_id: int):
        has_state = hasattr(dask_worker, "_explicit_comm_state")
        has_state_for_session = (
            has_state and session_id in dask_worker._explicit_comm_state
        )
        if has_state_for_session:
            n_workers = dask_worker._explicit_comm_state[session_id]["nworkers"]
        else:
            n_workers = None
        return {
            "has_state": has_state,
            "has_state_for_session": has_state_for_session,
            "n_workers": n_workers,
        }

    result_1 = client.run(check, session_id=context_1.sessionId)
    expected_values = {
        "has_state": True,
        "has_state_for_session": True,
        "n_workers": 2,
    }
    expected_1 = {k: expected_values for k in client.scheduler_info()["workers"]}
    assert result_1 == expected_1

    cluster.scale(3)
    client.wait_for_workers(3, timeout=5)

    context_2 = comms.default_comms()
    result_2 = client.run(check, session_id=context_2.sessionId)
    expected_values = {
        "has_state": True,
        "has_state_for_session": True,
        "n_workers": 3,
    }
    expected_2 = {k: expected_values for k in client.scheduler_info()["workers"]}
    assert result_2 == expected_2
2 changes: 1 addition & 1 deletion docs/source/explicit_comms.rst
@@ -2,7 +2,7 @@ Explicit-comms
==============

Communication and scheduling overhead can be a major bottleneck in Dask/Distributed. Dask-CUDA addresses this by introducing an API for explicit communication in Dask tasks.
The idea is that Dask/Distributed spawns workers and distribute data as usually while the user can submit tasks on the workers that communicate explicitly.
The idea is that Dask/Distributed spawns workers and distributes data as usual while the user can submit tasks on the workers that communicate explicitly.

This makes it possible to bypass Distributed's scheduler and write hand-tuned computation and communication patterns. Currently, Dask-CUDA includes an explicit-comms
implementation of the Dataframe `shuffle <../api/#dask_cuda.explicit_comms.dataframe.shuffle.shuffle>`_ operation used for merging and sorting.
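For orientation, enabling the explicit-comms shuffle might look like the sketch below. The "explicit-comms" configuration key (and the DASK_EXPLICIT_COMMS environment variable) is an assumption taken from the rest of this documentation page, and LocalCUDACluster needs at least one GPU; verify both against your installed dask-cuda.

import dask
import dask.datasets
from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # Setting the config key opts DataFrame shuffles into the explicit-comms path.
    with dask.config.set({"explicit-comms": True}):
        with LocalCUDACluster(n_workers=2) as cluster, Client(cluster):
            ddf = dask.datasets.timeseries()  # demo dataframe with an "id" column
            shuffled = ddf.shuffle("id")      # intended to take the explicit-comms path
            print(shuffled.head())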
