Skip to content

Commit

Permalink
un-fix other things
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 13, 2024
1 parent 3cd9d38 commit d3593eb
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 9 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:

batch = torch.randint(0, len(X_train), (batch_size,))

with torch.amp.autocast() if args.use_amp else nullcontext():
with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
grad_scaler.scale(loss).backward()

Expand Down
4 changes: 0 additions & 4 deletions hivemind/optim/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from typing import Dict, Optional

import torch
from packaging import version

torch_version = torch.__version__.split("+")[0]

from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer
Expand Down
4 changes: 2 additions & 2 deletions tests/test_p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ async def test_daemon_killed_on_del():

@pytest.mark.asyncio
async def test_startup_error_message():
with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers|Daemon failed to start"):
with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
await P2P.create(
initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
)

with pytest.raises(P2PDaemonError, match=r"Daemon failed to start|error accepting connection"):
with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
await P2P.create(startup_timeout=0.01) # Test that startup_timeout works


Expand Down
3 changes: 1 addition & 2 deletions tests/test_utils/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from typing import NamedTuple

from multiaddr import Multiaddr, protocols
from pkg_resources import resource_filename

from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

from test_utils.networking import get_free_port

TIMEOUT_DURATION = 30 # seconds

from pkg_resources import resource_filename

P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")


Expand Down

0 comments on commit d3593eb

Please sign in to comment.