Skip to content

Commit

Permalink
add --compile
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Dec 27, 2024
1 parent 35d2fe5 commit b312385
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/trompt_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
model = torch.compile(model) if args.compile else model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
lr_scheduler = ExponentialLR(optimizer, gamma=0.95)

Expand Down Expand Up @@ -240,6 +241,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()

os.environ['MASTER_ADDR'] = 'localhost'
Expand Down

0 comments on commit b312385

Please sign in to comment.