From 51e59426bcb3ef76bf00fe4aa1d91116e57078de Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 9 Jun 2024 22:39:04 +0100 Subject: [PATCH] Improve the process cleanup logic when running tests (#616) * Improve the process cleanup logic when running tests * Add graceful shutdown in test_dht_experts * Add graceful shutdown in test_fault_tolerance --- tests/conftest.py | 16 +++++++--------- tests/test_allreduce_fault_tolerance.py | 1 + tests/test_dht_experts.py | 6 ++++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 70f2535e1..0f747551f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import asyncio import gc -from contextlib import suppress import psutil import pytest @@ -40,13 +39,12 @@ def cleanup_children(): children = psutil.Process().children(recursive=True) if children: - logger.info(f"Cleaning up {len(children)} leftover child processes") - for child in children: - with suppress(psutil.NoSuchProcess): - child.terminate() - psutil.wait_procs(children, timeout=1) - for child in children: - with suppress(psutil.NoSuchProcess): - child.kill() + gone, alive = psutil.wait_procs(children, timeout=0.1) + logger.debug(f"Cleaning up {len(alive)} leftover child processes") + for child in alive: + child.terminate() + gone, alive = psutil.wait_procs(alive, timeout=1) + for child in alive: + child.kill() MPFuture.reset_backend() diff --git a/tests/test_allreduce_fault_tolerance.py b/tests/test_allreduce_fault_tolerance.py index d1ee66d99..12e310eba 100644 --- a/tests/test_allreduce_fault_tolerance.py +++ b/tests/test_allreduce_fault_tolerance.py @@ -209,3 +209,4 @@ def _make_tensors(): for averager in averagers: averager.shutdown() + dht.shutdown() diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index 2862dcff5..0332a1a59 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -82,6 +82,10 @@ def test_beam_search( assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts) assert all(len(experts) == beam_size for experts in batch_experts) + you.shutdown() + for dht_instance in dht_instances: + dht_instance.shutdown() + @pytest.mark.forked def test_dht_single_node(): @@ -119,6 +123,8 @@ def test_dht_single_node(): with pytest.raises(AssertionError): beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."]) + node.shutdown() + def test_uid_patterns(): valid_experts = [