diff --git a/tests/unit/common.py b/tests/unit/common.py index c9eb7ffaa5f4..69ba4c2708ac 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -25,6 +25,8 @@ # Worker timeout for tests that hang DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600')) +warn_reuse_dist_env = False + def is_rocm_pytorch(): return hasattr(torch.version, 'hip') and torch.version.hip is not None @@ -179,6 +181,13 @@ def _launch_daemonic_procs(self, num_procs): print("Ignoring reuse_dist_env for hpu") self.reuse_dist_env = False + global warn_reuse_dist_env + if self.reuse_dist_env and not warn_reuse_dist_env: + # Currently we see a memory leak in tests that reuse the distributed environment + print("Ignoring reuse_dist_env and forcibly setting it to False") + warn_reuse_dist_env = True + self.reuse_dist_env = False + if self.reuse_dist_env: if num_procs not in self._pool_cache: self._pool_cache[num_procs] = mp.Pool(processes=num_procs)