diff --git a/train.py b/train.py
index 9f5c7ff89..9c8883772 100755
--- a/train.py
+++ b/train.py
@@ -641,10 +641,11 @@ def flush(self):
         model = torch.compile(model, mode=args.compiler_mode,
                               backend=args.compiler_backend)
-        #TODO: Optimize DDP is currently not supported with QAT.
-        # Once pytorch supports DDP with higher order ops, we can enable optimize DDP with QAT.
-        # https://github.com/pytorch/pytorch/issues/104674.
-        torch._dynamo.config.optimize_ddp=False
+        # TODO: Optimize DDP is currently not supported with QAT.
+        # Once pytorch supports DDP with higher order ops,
+        # we can enable optimize DDP with QAT.
+        # https://github.com/pytorch/pytorch/issues/104674.
+        torch._dynamo.config.optimize_ddp = False  # pylint: disable=protected-access
         msglogger.info(
             'torch.compile() successful, mode=%s, cache limit=%d',
             args.compiler_mode,
@@ -1089,17 +1090,16 @@ def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=No
         model = apputils.load_lean_checkpoint(model, best_ckpt_path)
     if ddp:
-        model = DistributedDataParallel(
-                model,
-                device_ids=[local_rank] if args.device == 'cuda' else None,
-                output_device=local_rank if args.device == 'cuda' else None,
-            )
+        model = DistributedDataParallel(
+            model,
+            device_ids=[local_rank] if args.device == 'cuda' else None,
+            output_device=local_rank if args.device == 'cuda' else None,
+        )
     if dynamo:
         torch._dynamo.reset()  # pylint: disable=protected-access
         model = torch.compile(model, mode=args.compiler_mode,
-                              backend=args.compiler_backend,
-                              )
+                              backend=args.compiler_backend)
         msglogger.info(
             'torch.compile() successful, mode=%s, cache limit=%d',
             args.compiler_mode,
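
For context, below is a minimal, self-contained sketch of the wrap-then-compile pattern this diff tidies up: wrap the model in DistributedDataParallel, disable Dynamo's DDP optimizer (currently incompatible with QAT's higher-order ops, see pytorch/pytorch#104674), then call torch.compile(). It assumes a torchrun-style launch that sets LOCAL_RANK; the helper name wrap_and_compile and the placeholder model are illustrative and not part of train.py.

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel


    def wrap_and_compile(model, local_rank, compiler_mode="default", compiler_backend="inductor"):
        """Wrap a model in DDP, then compile it with Dynamo's DDP optimizer disabled."""
        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )

        # Dynamo's graph-break-based DDP optimization does not work with QAT's
        # higher-order ops (pytorch/pytorch#104674), so it is turned off.
        torch._dynamo.config.optimize_ddp = False  # pylint: disable=protected-access
        torch._dynamo.reset()  # pylint: disable=protected-access
        return torch.compile(model, mode=compiler_mode, backend=compiler_backend)


    if __name__ == "__main__":
        # Launched via torchrun, which provides LOCAL_RANK in the environment.
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        net = torch.nn.Linear(16, 4).to(local_rank)  # placeholder model
        net = wrap_and_compile(net, local_rank)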