
Commit

Merge branch 'main' into small_fixes
Andrei-Aksionov authored Dec 15, 2024
2 parents a4e1134 + 4b3dd3b commit c01b31a
Showing 1 changed file with 2 additions and 2 deletions.
litgpt/api.py: 4 changes (2 additions, 2 deletions)
@@ -386,7 +386,7 @@ def distribute(
         model.eval()
 
         if generate_strategy == "sequential":
-            state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu")
+            state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False)
             model.load_state_dict(state_dict, assign=True)
             model = fabric.setup_module(model, move_to_device=False)
 
@@ -405,7 +405,7 @@ def distribute(
             pbar = tqdm(total=fabric.world_size, desc="Loading model weights")
             for rank in range(fabric.world_size):
                 if fabric.global_rank == rank:
-                    state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu")
+                    state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False)
                     model.load_state_dict(state_dict, assign=True)
 
                     # cannot use `.setup_module` because it will wrap with DDP
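For context, both hunks make the same change: the `weights_only` argument of `torch.load` is pinned to `False` instead of being left to PyTorch's (changing) default. Below is a minimal sketch of what the flag controls, assuming a recent PyTorch where `torch.load` accepts the `mmap` and `weights_only` keyword arguments; the checkpoint location is a placeholder, not litgpt code.

    # Minimal, self-contained sketch of the flag this commit pins explicitly.
    # Not litgpt code; the checkpoint path is a hypothetical placeholder.
    from pathlib import Path

    import torch

    checkpoint_dir = Path("checkpoints/some-model")  # hypothetical location

    # With weights_only=True (the direction PyTorch's default is moving),
    # unpickling is restricted to tensors and a small set of allowlisted types,
    # so checkpoints that contain other pickled Python objects fail to load.
    # Passing weights_only=False keeps full unpickling, which is only appropriate
    # for checkpoints from a trusted source.
    state_dict = torch.load(
        str(checkpoint_dir / "lit_model.pth"),
        mmap=True,
        map_location="cpu",
        weights_only=False,
    )

Spelling the argument out keeps the loading behaviour stable across PyTorch releases, presumably the motivation for this small fix.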
