diff --git a/dask_cuda/explicit_comms/comms.py b/dask_cuda/explicit_comms/comms.py index 0fe5422d8..aa2b71c81 100644 --- a/dask_cuda/explicit_comms/comms.py +++ b/dask_cuda/explicit_comms/comms.py @@ -1,15 +1,21 @@ +# Copyright (c) 2021-2025 NVIDIA CORPORATION. import asyncio import concurrent.futures import contextlib import time import uuid +import weakref from typing import Any, Dict, Hashable, Iterable, List, Optional import distributed.comm +from dask.tokenize import tokenize from distributed import Client, Worker, default_client, get_worker from distributed.comm.addressing import parse_address, parse_host_port, unparse_address -_default_comms = None +# Mapping tokenize(client ID, [worker addresses]) to CommsContext +_comms_cache: weakref.WeakValueDictionary[ + str, "CommsContext" +] = weakref.WeakValueDictionary() def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs): @@ -38,9 +44,10 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs): def default_comms(client: Optional[Client] = None) -> "CommsContext": - """Return the default comms object + """Return the default comms object for ``client``. - Creates a new default comms object if no one exist. + Creates a new default comms object if one does not already exist + for ``client``. Parameters ---------- @@ -52,11 +59,39 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext": ------- comms: CommsContext The default comms object + + Notes + ----- + There are some subtle points around explicit-comms and the lifecycle + of a Dask Cluster. + + A :class:`CommsContext` establishes explicit communication channels + between the workers *at the time it's created*. If workers are added + or removed, they will not be included in the communication channels + with the other workers. 
+ + If you need to refresh the explicit communications channels, then + create a new :class:`CommsContext` object or call ``default_comms`` + again after workers have been added to or removed from the cluster. """ - global _default_comms - if _default_comms is None: - _default_comms = CommsContext(client=client) - return _default_comms + # Comms are unique to a client, so we need to know. + client = client or default_client() + # What behavior do we want with two clients referring to the same cluster? + # a = Client(cluster); b = Client(cluster) + # Their `.id`s are different, but don't they have the same comms? + # From the "this does stuff to the workers" point of view, yes they're + # the same. But `CommsContext` has a reference to `.client`, so not quite. + # + # Recommendation: either ignore this (multiple clients are common) or + # deprecate `CommsContext.client` and make users pass in + # `client` when they call stuff (default of default_client()). + token = tokenize(client.id, list(client.scheduler_info()["workers"].keys())) + maybe_comms = _comms_cache.get(token) + if maybe_comms is None: + maybe_comms = CommsContext(client=client) + _comms_cache[token] = maybe_comms + + return maybe_comms def worker_state(sessionId: Optional[int] = None) -> dict: diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 2f79251dd..913c9b038 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -1,3 +1,6 @@ +#!/bin/bash +# Copyright (c) 2021-2025 NVIDIA CORPORATION. 
+ import asyncio import multiprocessing as mp import os @@ -415,3 +418,79 @@ def test_lock_workers(): p.join() assert all(p.exitcode == 0 for p in ps) + + +def test_create_destroy_create(): + # https://github.com/rapidsai/dask-cuda/issues/1450 + assert len(comms._comms_cache) == 0 + with LocalCluster(n_workers=1) as cluster: + with Client(cluster) as client: + context = comms.default_comms() + scheduler_addresses_old = list(client.scheduler_info()["workers"].keys()) + comms_addresses_old = list(comms.default_comms().worker_addresses) + assert comms.default_comms() is context + assert len(comms._comms_cache) == 1 + + # Add a worker, which should have a new comms object + cluster.scale(2) + client.wait_for_workers(2) + context2 = comms.default_comms() + assert context is not context2 + assert len(comms._comms_cache) == 2 + + del context + del context2 + assert len(comms._comms_cache) == 0 + assert scheduler_addresses_old == comms_addresses_old + + # A new cluster should have a new comms object. Previously, this failed + # because we referenced the old cluster's addresses. 
+ with LocalCluster(n_workers=1) as cluster: + with Client(cluster) as client: + scheduler_addresses_new = list(client.scheduler_info()["workers"].keys()) + comms_addresses_new = list(comms.default_comms().worker_addresses) + + assert scheduler_addresses_new == comms_addresses_new + + +def test_update(): + cluster = LocalCluster(n_workers=2) + client = cluster.get_client() + context_1 = comms.default_comms() + + def check(dask_worker, session_id: int): + has_state = hasattr(dask_worker, "_explicit_comm_state") + has_state_for_session = ( + has_state and session_id in dask_worker._explicit_comm_state + ) + if has_state_for_session: + n_workers = dask_worker._explicit_comm_state[session_id]["nworkers"] + else: + n_workers = None + return { + "has_state": has_state, + "has_state_for_session": has_state_for_session, + "n_workers": n_workers, + } + + result_1 = client.run(check, session_id=context_1.sessionId) + expected_values = { + "has_state": True, + "has_state_for_session": True, + "n_workers": 2, + } + expected_1 = {k: expected_values for k in client.scheduler_info()["workers"]} + assert result_1 == expected_1 + + cluster.scale(3) + client.wait_for_workers(3, timeout=5) + + context_2 = comms.default_comms() + result_2 = client.run(check, session_id=context_2.sessionId) + expected_values = { + "has_state": True, + "has_state_for_session": True, + "n_workers": 3, + } + expected_2 = {k: expected_values for k in client.scheduler_info()["workers"]} + assert result_2 == expected_2 diff --git a/docs/source/explicit_comms.rst b/docs/source/explicit_comms.rst index bd024dbba..00e04241b 100644 --- a/docs/source/explicit_comms.rst +++ b/docs/source/explicit_comms.rst @@ -2,7 +2,7 @@ Explicit-comms ============== Communication and scheduling overhead can be a major bottleneck in Dask/Distributed. Dask-CUDA addresses this by introducing an API for explicit communication in Dask tasks. 
-The idea is that Dask/Distributed spawns workers and distribute data as usually while the user can submit tasks on the workers that communicate explicitly. +The idea is that Dask/Distributed spawns workers and distributes data as usual while the user can submit tasks on the workers that communicate explicitly. This makes it possible to bypass Distributed's scheduler and write hand-tuned computation and communication patterns. Currently, Dask-CUDA includes an explicit-comms implementation of the Dataframe `shuffle <../api/#dask_cuda.explicit_comms.dataframe.shuffle.shuffle>`_ operation used for merging and sorting.