Skip to content

Commit

Permalink
Update model.py (#1687)
Browse files Browse the repository at this point in the history
  • Loading branch information
KartikP authored Jan 10, 2025
1 parent 3352845 commit 85ffe43
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions brainscore_vision/models/ReAlnet01/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@
import torch.nn.functional as F
import h5py
import random

import functools

import torchvision.models

import gdown

from brainscore_vision.model_helpers.s3 import load_weight_file
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images

Expand Down Expand Up @@ -190,14 +186,14 @@ def forward(self, imgs):
# Build encoder model
encoder = Encoder(realnet, 340)

# Download weights
url = 'https://drive.google.com/uc?id=1AtpS7dPV8t3e1aT8a4Nu-mIFkbWRH-ff'
output_file = "best_model_params.pt"
gdown.download(url, output_file, quiet=False)

# Download weights (Brain-Score team modification)
weights_path = load_weight_file(bucket="brainscore-storage", folder_name="brainscore-vision/models",
relative_path="ReAlnet01/best_model_params.pt",
version_id="3EduTJ.gv2rlVA_W1b7KSkOfyldAWIDc",
sha1="05e4e401e8734b97e561aad306fc584b7e027225")

# Load weights onto CPU and remove "module." from keys
weights = torch.load(output_file, map_location='cpu')
weights = torch.load(weights_path, map_location='cpu')
new_state_dict = {}
for key, val in weights.items():
# remove "module." (if it exists) from the key
Expand All @@ -206,11 +202,6 @@ def forward(self, imgs):

encoder.load_state_dict(new_state_dict)


# Load weights onto CPU
# weights = torch.load(output_file, map_location='cpu')
# encoder.load_state_dict(weights)

# Retrieve the realnet portion from the encoder
realnet = encoder.realnet
realnet.eval()
Expand Down

0 comments on commit 85ffe43

Please sign in to comment.