diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py index 1218b00c2..8f93ff692 100644 --- a/benchmarks/benchmark_optimizer.py +++ b/benchmarks/benchmark_optimizer.py @@ -107,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: batch = torch.randint(0, len(X_train), (batch_size,)) - with torch.amp.autocast() if args.use_amp else nullcontext(): + with torch.cuda.amp.autocast() if args.use_amp else nullcontext(): loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device)) grad_scaler.scale(loss).backward() diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index 6af94bf5e..704f859c9 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -4,10 +4,6 @@ from typing import Dict, Optional import torch -from packaging import version - -torch_version = torch.__version__.split("+")[0] - from torch.cuda.amp import GradScaler as TorchGradScaler from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state from torch.optim import Optimizer as TorchOptimizer diff --git a/tests/test_p2p_daemon.py b/tests/test_p2p_daemon.py index 7cfc70244..19135d875 100644 --- a/tests/test_p2p_daemon.py +++ b/tests/test_p2p_daemon.py @@ -39,12 +39,12 @@ async def test_daemon_killed_on_del(): @pytest.mark.asyncio async def test_startup_error_message(): - with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers|Daemon failed to start"): + with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"): await P2P.create( initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"] ) - with pytest.raises(P2PDaemonError, match=r"Daemon failed to start|error accepting connection"): + with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"): await P2P.create(startup_timeout=0.01) # Test that startup_timeout works diff --git a/tests/test_utils/p2p_daemon.py b/tests/test_utils/p2p_daemon.py index 83f86dfe3..ebf41e52a 100644 --- a/tests/test_utils/p2p_daemon.py +++ b/tests/test_utils/p2p_daemon.py @@ -8,6 +8,7 @@ from typing import NamedTuple from multiaddr import Multiaddr, protocols +from pkg_resources import resource_filename from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client @@ -15,8 +16,6 @@ TIMEOUT_DURATION = 30 # seconds -from pkg_resources import resource_filename - P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")