diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py
index 8f93ff692..b976f3347 100644
--- a/benchmarks/benchmark_optimizer.py
+++ b/benchmarks/benchmark_optimizer.py
@@ -98,7 +98,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
             grad_scaler = hivemind.GradScaler()
         else:
             # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
-            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+            grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)
 
         prev_time = time.perf_counter()
 
@@ -107,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
 
             batch = torch.randint(0, len(X_train), (batch_size,))
 
-            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+            with torch.amp.autocast("cuda") if args.use_amp else nullcontext():
                 loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
                 grad_scaler.scale(loss).backward()
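
Context for the change: torch.cuda.amp.GradScaler and torch.cuda.amp.autocast are deprecated aliases of the device-generic torch.amp API (torch.amp.GradScaler appeared around PyTorch 2.3). The sketch below shows how the migrated calls are typically used end to end; it is an illustrative standalone example, not code from the benchmark, and the model/tensor names are placeholders.

import torch
import torch.nn.functional as F

# Placeholder setup for illustration only (not the benchmark's model or data).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# torch.amp.GradScaler takes the device as its first argument (default "cuda");
# torch.amp.autocast requires an explicit device_type, unlike the removed
# torch.cuda.amp.autocast(), which implied "cuda".
grad_scaler = torch.amp.GradScaler(device, enabled=(device == "cuda"))

inputs = torch.randn(8, 16, device=device)
targets = torch.randint(0, 4, (8,), device=device)

with torch.amp.autocast(device, enabled=(device == "cuda")):
    loss = F.cross_entropy(model(inputs), targets)

# With the scaler disabled, scale()/step()/update() fall back to the plain
# backward()/step() path, so the same loop also runs on CPU.
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()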