Skip to content

Commit

Permalink
restore fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 13, 2024
1 parent 4c52dbf commit 3276c94
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
grad_scaler = hivemind.GradScaler()
else:
# check that hivemind.Optimizer supports regular PyTorch grad scaler as well
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)

prev_time = time.perf_counter()

Expand All @@ -107,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:

batch = torch.randint(0, len(X_train), (batch_size,))

with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
with torch.amp.autocast() if args.use_amp else nullcontext():
loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
grad_scaler.scale(loss).backward()

Expand Down

0 comments on commit 3276c94

Please sign in to comment.