Skip to content

Commit

Permalink
Linter updates
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Nov 21, 2024
1 parent 404b29c commit 06c4d8f
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,11 @@ def flush(self):
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend)

#TODO: Optimize DDP is currently not supported with QAT.
# Once pytorch supports DDP with higher order ops, we can enable optimize DDP with QAT.
# https://github.com/pytorch/pytorch/issues/104674.
torch._dynamo.config.optimize_ddp=False
# TODO: Optimize DDP is currently not supported with QAT.
# Once pytorch supports DDP with higher order ops,
# we can enable optimize DDP with QAT.
# https://github.com/pytorch/pytorch/issues/104674.
torch._dynamo.config.optimize_ddp = False # pylint: disable=protected-access
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
Expand Down Expand Up @@ -1089,17 +1090,16 @@ def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=No
model = apputils.load_lean_checkpoint(model, best_ckpt_path)

if ddp:
model = DistributedDataParallel(
model,
device_ids=[local_rank] if args.device == 'cuda' else None,
output_device=local_rank if args.device == 'cuda' else None,
)
model = DistributedDataParallel(
model,
device_ids=[local_rank] if args.device == 'cuda' else None,
output_device=local_rank if args.device == 'cuda' else None,
)

if dynamo:
torch._dynamo.reset() # pylint: disable=protected-access
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend,
)
backend=args.compiler_backend)
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
Expand Down

0 comments on commit 06c4d8f

Please sign in to comment.