Explicit Comms Object Not Cleared After Cluster State Change or Restart #1450

Closed
VibhuJawa opened this issue Feb 13, 2025 · 8 comments · Fixed by #1451

@VibhuJawa (Member) commented Feb 13, 2025

Description

The explicit_comms module in dask_cuda does not properly reset or clean up its state after a cluster state change or when a new cluster is started. As a result, default_comms().worker_addresses retains references to old worker addresses even after the original cluster has been cleaned up. The communication object does not respect the lifetime of the worker or cluster objects.

import dask_cuda
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask_cuda.explicit_comms import comms

# Start a cluster and check the worker addresses
with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
    print(client.scheduler_info()["workers"].keys())  # Expected: Current worker
    print(comms.default_comms().worker_addresses)  # Expected: Current worker

Both addresses align:

dict_keys(['tcp://127.0.0.1:43585'])
['tcp://127.0.0.1:43585']

After the first cluster has been cleaned up, comms.default_comms() still retains the old worker address:

print(comms.default_comms().worker_addresses)
['tcp://127.0.0.1:43585']  # Unexpected: old worker address still present

Start a new cluster and check the worker addresses:

with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
    print(client.scheduler_info()["workers"].keys())  # Expected: new worker
    print(comms.default_comms().worker_addresses)  # Unexpected: old worker address still retained

The addresses no longer align:

dict_keys(['tcp://127.0.0.1:38569'])  # Expected: new worker
['tcp://127.0.0.1:43585']  # Unexpected: old worker address still retained

Observed Behavior

When the first cluster is cleaned up, comms.default_comms().worker_addresses still references the old worker.

After starting a new cluster, default_comms().worker_addresses continues to reference the worker from the previous cluster rather than updating to the new worker.

Expected Behavior

comms.default_comms().worker_addresses should be cleared or reset when the cluster is shut down or workers are updated.
When a new cluster is started, default_comms().worker_addresses should reflect the new worker(s) only.

Additional Context:

This issue led to multiple unexpected CI failures in NeMo Curator (PR #540), which took significant effort to diagnose and debug.

CC: @ayushdg and @sarahyurick, who did a lot of the work triaging this issue.

@TomAugspurger (Contributor)
Thanks for the excellent reproducer.

It seems like we need the cached _default_comms object that lives in the comms module to have some notion of which client it belongs to, so that we don't accidentally cross streams as clients / clusters come and go. We can do this with a dictionary (or a WeakValueDictionary to avoid keeping these CommsContext objects alive, if they're expensive; though I'm not sure who else would hold a reference to the comms object, so maybe we do need somebody to keep it alive?):

diff --git a/dask_cuda/explicit_comms/comms.py b/dask_cuda/explicit_comms/comms.py
index 0fe5422..4ff5f76 100644
--- a/dask_cuda/explicit_comms/comms.py
+++ b/dask_cuda/explicit_comms/comms.py
@@ -1,4 +1,5 @@
 import asyncio
+import weakref
 import concurrent.futures
 import contextlib
 import time
@@ -9,7 +10,8 @@ import distributed.comm
 from distributed import Client, Worker, default_client, get_worker
 from distributed.comm.addressing import parse_address, parse_host_port, unparse_address
 
-_default_comms = None
+# Mapping client ID to CommsContext
+_comms_cache: weakref.WeakValueDictionary[str, "CommsContext"] = weakref.WeakValueDictionary()
 
 
 def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
@@ -53,10 +55,14 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
     comms: CommsContext
         The default comms object
     """
-    global _default_comms
-    if _default_comms is None:
-        _default_comms = CommsContext(client=client)
-    return _default_comms
+    # Comms are unique to a client, so we need to know.
+    client = client or default_client()
+    maybe_comms = _comms_cache.get(client.id)
+    if maybe_comms is None:
+        maybe_comms = CommsContext(client=client)
+        _comms_cache[client.id] = maybe_comms
+
+    return maybe_comms

That passes a test ensuring that default_comms() matches the worker addresses for a newly created cluster.
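
A minimal sketch of such a test (illustrative only; the test name and structure are my assumptions, not necessarily what the PR adds):

from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.explicit_comms import comms

def test_default_comms_tracks_new_cluster():
    # First cluster: the comms object should see exactly this cluster's workers.
    with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
        assert set(comms.default_comms().worker_addresses) == set(
            client.scheduler_info()["workers"]
        )
    # Second cluster: a fresh CommsContext should be created for the new client,
    # so the addresses must match the new workers, not the old ones.
    with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
        assert set(comms.default_comms().worker_addresses) == set(
            client.scheduler_info()["workers"]
        )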

> comms.default_comms().worker_addresses should be cleared or reset when the cluster is shut down or workers are updated.
> When a new cluster is started, default_comms().worker_addresses should reflect the new worker(s) only.

I'll need to better understand the "workers are updated" case. Normally I would recommend a scheduler plugin, but this seems to live on the client.

@rjzamora (Member) commented Feb 13, 2025

Thanks for looking into this @TomAugspurger !

I was thinking of a very similar fix. My one qualm with your current suggestion is that we may want to create a token from the specific worker addresses and the client.id value. For example:

from dask.tokenize import tokenize

cache_key = tokenize(client.id, client.scheduler_info()["workers"].keys())

I realize it may be an "edge case", but it's possible that the user could add or lose workers between shuffle operations, and we want to make sure the comms context is updated if the worker addresses are different for the same client. Does that make sense?
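
For illustration, default_comms could use that token as the cache key. This is a rough sketch only, assuming the CommsContext class and a module-level cache like the one in the diff above (shown here as a plain dict for simplicity); it is not the merged implementation:

from typing import Optional

from dask.tokenize import tokenize
from distributed import Client, default_client

# Hypothetical cache keyed by a token of (client id, worker addresses).
_comms_cache: dict[str, "CommsContext"] = {}

def default_comms(client: Optional[Client] = None) -> "CommsContext":
    client = client or default_client()
    # Key the cache on both the client and the current worker addresses, so a
    # changed cluster yields a new CommsContext even for the same client.
    cache_key = tokenize(client.id, sorted(client.scheduler_info()["workers"]))
    comms_obj = _comms_cache.get(cache_key)
    if comms_obj is None:
        comms_obj = CommsContext(client=client)
        _comms_cache[cache_key] = comms_obj
    return comms_obj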

@VibhuJawa (Member, Author) commented Feb 13, 2025

> I realize it may be an "edge case", but it's possible that the user could add or lose workers between shuffle operations, and we want to make sure the comms context is updated if the worker addresses are different for the same client. Does that make sense?

Honestly, in Curator we lose workers a decent number of times at scale. If we use the explicit-comms shuffle twice and lose workers in between, that could be problematic. We should ensure this case is captured as well.

@TomAugspurger (Contributor)
Having never used explicit comms before, what's the expected behavior when the cluster changes (workers added or removed / die) after the comms have been established?

I see that dask_cuda.explicit_comms.dataframe.shuffle starts with a call to comms.default_comms(). If we include the current set of workers (using their addresses) in the cache key like @rjzamora suggested, then subsequent shuffles on a "modified" cluster would get a fresh comms object with an up-to-date picture of the cluster (though it looks like part of setting up CommsContext does some work on each worker, so we need to make sure that can be done multiple times if necessary, for the worker to update its understanding of the world. Maybe that's already possible).

@rjzamora (Member)
> what's the expected behavior when the cluster changes (workers added or removed / die) after the comms have been established?

The original explicit-comms design definitely assumed the cluster was static and made no attempt to handle worker failures. I don't think we should try to handle the case that workers are lost during a shuffle. However, I do think it's reasonable to handle the case that workers are lost in between shuffles (unless this support proves unrealistic).

> (though it looks like part of setting up CommsContext does some work on each worker, so we need to make sure that can be done multiple times if necessary, for the worker to update its understanding of the world. Maybe that's already possible)

Right - I haven't exactly vetted this idea yet, so it may not be realistic. If that turns out to be the case, it may still make sense to keep track of the worker addresses and warn the user that the workers have changed.
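
As a rough illustration of that fallback idea (hypothetical helper, not part of any PR): the cached CommsContext could remember the worker addresses it was created with and warn on mismatch.

import warnings

def _warn_if_workers_changed(comms_obj, client):
    # Hypothetical check: compare the addresses the CommsContext was built
    # with against the cluster's current workers and warn on any difference.
    current = set(client.scheduler_info()["workers"])
    if set(comms_obj.worker_addresses) != current:
        warnings.warn(
            "The cluster's workers have changed since this CommsContext was "
            "created; explicit-comms operations may not see the full cluster."
        )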

@TomAugspurger (Contributor) commented Feb 13, 2025

All that makes sense. Within-shuffle failures seem extremely hard to reason about. Making sure that comms.default_comms() does something sensible between shuffles should be easy (assuming it's safe to rerun the setup as new workers come online). I'll put up a PR in a bit or early tomorrow.

@rjzamora (Member)
Awesome, thanks @TomAugspurger!

@VibhuJawa - For now, you may be able to add the following code before you spin up a new dask-cuda cluster to manually clear the cache:

import dask_cuda.explicit_comms.comms

dask_cuda.explicit_comms.comms._default_comms = None
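
For example, applied to the reproducer above (an illustrative sketch of the workaround with the current unpatched module, not a permanent fix):

from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.explicit_comms import comms

with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
    comms.default_comms()  # explicit-comms work against the first cluster

# Clear the cached CommsContext before spinning up the next cluster.
comms._default_comms = None

with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
    # default_comms() now rebuilds the context for the new cluster's workers.
    print(comms.default_comms().worker_addresses)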

TomAugspurger added a commit to TomAugspurger/dask-cuda that referenced this issue Feb 14, 2025
This PR updates the CommContext caching to be keyed by some information
about the cluster, rather than a single global. This prevents us from
using a stale comms object after the cluster changes (add or remove
workers) or is recreated entirely.

Closes rapidsai#1450
@TomAugspurger (Contributor)
#1451 should take care of this.

rapids-bot bot pushed a commit that referenced this issue Feb 19, 2025
This PR updates the CommContext caching to be keyed by some information about the cluster, rather than a single global. This prevents us from using a stale comms object after the cluster changes (add or remove workers) or is recreated entirely.

Closes #1450

Authors:
  - Tom Augspurger (https://github.com/TomAugspurger)

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

URL: #1451