diff --git a/litgpt/api.py b/litgpt/api.py index a114fdd512..ea156ce600 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -386,7 +386,7 @@ def distribute( model.eval() if generate_strategy == "sequential": - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) model = fabric.setup_module(model, move_to_device=False) @@ -405,7 +405,7 @@ def distribute( pbar = tqdm(total=fabric.world_size, desc="Loading model weights") for rank in range(fabric.world_size): if fabric.global_rank == rank: - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) # cannot use `.setup_module` because it will wrap with DDP