From 402951b84dc8c917d19191fa1de3eb70f40b0d9a Mon Sep 17 00:00:00 2001 From: Kartik Pradeepan Date: Fri, 10 Jan 2025 14:51:58 -0500 Subject: [PATCH] Update model.py --- brainscore_vision/models/ReAlnet01/model.py | 23 +++++++-------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/brainscore_vision/models/ReAlnet01/model.py b/brainscore_vision/models/ReAlnet01/model.py index f815c0680..42b7d47fd 100644 --- a/brainscore_vision/models/ReAlnet01/model.py +++ b/brainscore_vision/models/ReAlnet01/model.py @@ -11,13 +11,9 @@ import torch.nn.functional as F import h5py import random - import functools - import torchvision.models - -import gdown - +from brainscore_vision.model_helpers.s3 import load_weight_file from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images @@ -190,14 +186,14 @@ def forward(self, imgs): # Build encoder model encoder = Encoder(realnet, 340) -# Download weights -url = 'https://drive.google.com/uc?id=1AtpS7dPV8t3e1aT8a4Nu-mIFkbWRH-ff' -output_file = "best_model_params.pt" -gdown.download(url, output_file, quiet=False) - +# Download weights (Brain-Score team modification) +weights_path = load_weight_file(bucket="brainscore-storage", folder_name="brainscore-vision/models", + relative_path="ReAlnet01/best_model_params.pt", + version_id="3EduTJ.gv2rlVA_W1b7KSkOfyldAWIDc", + sha1="05e4e401e8734b97e561aad306fc584b7e027225") # Load weights onto CPU and remove "module." from keys -weights = torch.load(output_file, map_location='cpu') +weights = torch.load(weights_path, map_location='cpu') new_state_dict = {} for key, val in weights.items(): # remove "module." (if it exists) from the key @@ -206,11 +202,6 @@ def forward(self, imgs): encoder.load_state_dict(new_state_dict) - -# Load weights onto CPU -# weights = torch.load(output_file, map_location='cpu') -# encoder.load_state_dict(weights) - # Retrieve the realnet portion from the encoder realnet = encoder.realnet realnet.eval()