weights_only=True for all occurrences of torch.load
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Jan 30, 2025
1 parent 519ee8a commit 57a5298
Showing 9 changed files with 15 additions and 14 deletions.
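All nine files get the same change: every torch.load call now passes weights_only=True, which tells PyTorch to use a restricted unpickler that only reconstructs tensors, primitive containers, and other allow-listed types, so loading an untrusted checkpoint can no longer execute arbitrary pickled code. A minimal sketch of the pattern, using a placeholder model and checkpoint path that are not part of this repository:

    import torch
    import torch.nn as nn

    # Placeholder model and checkpoint path, for illustration only.
    model = nn.Linear(10, 2)
    ckpt_path = "checkpoint.pt"

    # A plain state_dict holds only tensors, so it loads cleanly under the
    # restricted unpickler that weights_only=True enables.
    torch.save(model.state_dict(), ckpt_path)

    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

Checkpoints that pickle custom Python objects (for example, full training states carrying non-tensor metadata) may be rejected under this restriction; in recent PyTorch releases such types can be allow-listed with torch.serialization.add_safe_globals before loading.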
3 changes: 2 additions & 1 deletion terratorch/io/file.py
@@ -34,10 +34,11 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
                 torch.load(
                     os.path.join(save_dir, name),
                     map_location=torch.device(device),
+                    weights_only=True,
                 )
             )
         else:
-            model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu'))
+            model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu', weights_only=True))
 
     except Exception:
         print(
2 changes: 1 addition & 1 deletion terratorch/models/backbones/clay_v1/embedder.py
@@ -67,7 +67,7 @@ def __init__(
 
     def load_clay_weights(self, ckpt_path):
         "Load the weights from the Clay model encoder."
-        ckpt = torch.load(ckpt_path)
+        ckpt = torch.load(ckpt_path, weights_only=True)
         state_dict = ckpt.get("state_dict")
         state_dict = {
             re.sub(r"^model\.encoder\.", "", name): param
2 changes: 1 addition & 1 deletion terratorch/models/backbones/dofa_vit.py
@@ -141,7 +141,7 @@ def load_dofa_weights(model: nn.Module, ckpt_data: str | None = None, weights:
         repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
         filename = ckpt_data.split("/")[-1]
         ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
-    checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+    checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)
 
     for k in ["head.weight", "head.bias"]:
         if (
4 changes: 2 additions & 2 deletions terratorch/models/backbones/prithvi_vit.py
@@ -215,7 +215,7 @@ def _create_prithvi(
 
     if ckpt_path is not None:
         # Load model from checkpoint
-        state_dict = torch.load(ckpt_path, map_location="cpu")
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)
         model.load_state_dict(state_dict, strict=False)
     elif pretrained:
@@ -225,7 +225,7 @@ def _create_prithvi(
             # Load model from Hugging Face
             pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"],
                                               filename=pretrained_weights[variant]["hf_hub_filename"])
-            state_dict = torch.load(pretrained_path, map_location="cpu")
+            state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
             state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)
             model.load_state_dict(state_dict, strict=True)
         except RuntimeError as e:
2 changes: 1 addition & 1 deletion terratorch/models/backbones/scalemae.py
@@ -258,7 +258,7 @@ def vit_huge_patch14(**kwargs):
     return model
 
 def load_scalemae_weights(model: nn.Module, ckpt_data: str, model_bands: list[HLSBands], input_size: int = 224) -> nn.Module:
-    checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"]
+    checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)["model"]
     state_dict = model.state_dict()
 
     for k in ["head.weight", "head.bias"]:
4 changes: 2 additions & 2 deletions terratorch/models/backbones/torchgeo_resnet.py
@@ -448,8 +448,8 @@ def load_resnet_weights(model: nn.Module, model_bands, ckpt_data: str, weights:
         repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
         filename = ckpt_data.split("/")[-1]
         ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
-    # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"]
-    checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+
+    checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)
     state_dict = model.state_dict()
 
     for k in ["fc.weight", "fc.bias"]:
4 changes: 2 additions & 2 deletions terratorch/models/backbones/torchgeo_swin_satlas.py
@@ -236,8 +236,8 @@ def load_swin_weights(model: nn.Module, model_bands, ckpt_data: str, weights: We
         repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
         filename = ckpt_data.split("/")[-1]
         ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
-    # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"]
-    checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+
+    checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)
     state_dict = model.state_dict()
 
     for k in ["head.weight", "head.bias"]:
4 changes: 2 additions & 2 deletions terratorch/models/backbones/torchgeo_vit.py
@@ -179,8 +179,8 @@ def load_vit_weights(model: nn.Module, model_bands, ckpt_data: str, weights: Wei
         repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '')
         filename = ckpt_data.split("/")[-1]
         ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
-    # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"]
-    checkpoint_model = torch.load(ckpt_data, map_location="cpu")
+
+    checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)
     state_dict = model.state_dict()
 
     for k in ["head.weight", "head.bias"]:
4 changes: 2 additions & 2 deletions terratorch/models/satmae_model_factory.py
@@ -227,9 +227,9 @@ def build_model(
         backbone: nn.Module = ModelWrapper(model=backbone_template(**backbone_kwargs), kind=backbone_kind)
 
         if self.CPU_ONLY:
-            model_dict = torch.load(checkpoint_path, map_location="cpu")
+            model_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
         else:
-            model_dict = torch.load(checkpoint_path)
+            model_dict = torch.load(checkpoint_path, weights_only=True)
 
 
         # Filtering parameters from the model state_dict (when necessary)
