add ReAlnet01 to models
Jenkins committed Jan 9, 2025
1 parent d7b842a commit 112c7b1
Showing 4 changed files with 253 additions and 0 deletions.
8 changes: 8 additions & 0 deletions brainscore_vision/models/ReAlnet01/__init__.py
@@ -0,0 +1,8 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, LAYERS

model_registry['ReAlnet01'] = lambda: ModelCommitment(
    identifier='ReAlnet01',
    activations_model=get_model(),
    layers=LAYERS)
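
Note (not part of the commit): once registered, the model can be instantiated through the Brain-Score API. A minimal sketch, assuming brainscore_vision exposes the load_model helper:

from brainscore_vision import load_model

model = load_model('ReAlnet01')  # builds the ModelCommitment registered above, with its layer mapping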
223 changes: 223 additions & 0 deletions brainscore_vision/models/ReAlnet01/model.py
@@ -0,0 +1,223 @@
import math
from collections import OrderedDict
import torch
from torch import nn
from torchvision import transforms
import torch.utils.model_zoo
import os
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import torch.nn.functional as F
import h5py
import random

import functools

import torchvision.models

import gdown

from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images

LAYERS = ['V1', 'V2', 'V4', 'IT', 'decoder.avgpool']

class Flatten(nn.Module):
    """
    Helper module for flattening the input tensor to 1-D for use in Linear modules
    """
    def forward(self, x):
        return x.view(x.size(0), -1)

class Identity(nn.Module):
    """
    Helper module that stores the current tensor. Useful for accessing it by name
    """
    def forward(self, x):
        return x

class CORblock_S(nn.Module):
    scale = 4  # scale of the bottleneck convolution channels

    def __init__(self, in_channels, out_channels, times=1):
        super().__init__()
        self.times = times

        self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.skip = nn.Conv2d(out_channels, out_channels,
                              kernel_size=1, stride=2, bias=False)
        self.norm_skip = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(out_channels, out_channels * self.scale,
                               kernel_size=1, bias=False)
        self.nonlin1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels * self.scale, out_channels * self.scale,
                               kernel_size=3, stride=2, padding=1, bias=False)
        self.nonlin2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(out_channels * self.scale, out_channels,
                               kernel_size=1, bias=False)
        self.nonlin3 = nn.ReLU(inplace=True)

        self.output = Identity()  # for easy access to this block's output

        # need BatchNorm for each time step for training to work well
        for t in range(self.times):
            setattr(self, f'norm1_{t}', nn.BatchNorm2d(out_channels * self.scale))
            setattr(self, f'norm2_{t}', nn.BatchNorm2d(out_channels * self.scale))
            setattr(self, f'norm3_{t}', nn.BatchNorm2d(out_channels))
    def forward(self, inp):
        x = self.conv_input(inp)
        for t in range(self.times):
            if t == 0:
                # first timestep: downsample via a strided skip connection and a strided conv2
                skip = self.norm_skip(self.skip(x))
                self.conv2.stride = (2, 2)
            else:
                # later timesteps: identity skip, keep the spatial resolution
                skip = x
                self.conv2.stride = (1, 1)

            x = self.conv1(x)
            x = getattr(self, f'norm1_{t}')(x)
            x = self.nonlin1(x)

            x = self.conv2(x)
            x = getattr(self, f'norm2_{t}')(x)
            x = self.nonlin2(x)

            x = self.conv3(x)
            x = getattr(self, f'norm3_{t}')(x)

            x += skip
            x = self.nonlin3(x)
            output = self.output(x)

        return output

def CORnet_S():
    model = nn.Sequential(OrderedDict([
        ('V1', nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm1', nn.BatchNorm2d(64)),
            ('nonlin1', nn.ReLU(inplace=True)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
            ('norm2', nn.BatchNorm2d(64)),
            ('nonlin2', nn.ReLU(inplace=True)),
            ('output', Identity())
        ]))),
        ('V2', CORblock_S(64, 128, times=2)),
        ('V4', CORblock_S(128, 256, times=4)),
        ('IT', CORblock_S(256, 512, times=2)),
        ('decoder', nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
            ('linear', nn.Linear(512, 1000)),
            ('output', Identity())
        ])))
    ]))

    # weight initialization
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    return model

class Encoder(nn.Module):
    def __init__(self, realnet, n_output):
        super(Encoder, self).__init__()

        # CORnet
        self.realnet = realnet

        # fully connected layers (input sizes are the flattened CORnet_S feature maps for 224x224 input)
        self.fc_v1 = nn.Linear(200704, 128)  # 64 * 56 * 56
        self.fc_v2 = nn.Linear(100352, 128)  # 128 * 28 * 28
        self.fc_v4 = nn.Linear(50176, 128)   # 256 * 14 * 14
        self.fc_it = nn.Linear(25088, 128)   # 512 * 7 * 7
        self.fc = nn.Linear(512, n_output)   # 4 concatenated 128-d branches
        self.activation = nn.ReLU()

    def forward(self, imgs):
        # forward pass through CORnet_S
        outputs = self.realnet(imgs)

        N = len(imgs)
        v1_outputs = self.realnet.V1(imgs)        # N * 64 * 56 * 56
        v2_outputs = self.realnet.V2(v1_outputs)  # N * 128 * 28 * 28
        v4_outputs = self.realnet.V4(v2_outputs)  # N * 256 * 14 * 14
        it_outputs = self.realnet.IT(v4_outputs)  # N * 512 * 7 * 7

        # flatten and pass through fully connected layers
        v1_features = self.fc_v1(v1_outputs.view(N, -1))
        v1_features = self.activation(v1_features)

        v2_features = self.fc_v2(v2_outputs.view(N, -1))
        v2_features = self.activation(v2_features)

        v4_features = self.fc_v4(v4_outputs.view(N, -1))
        v4_features = self.activation(v4_features)

        it_features = self.fc_it(it_outputs.view(N, -1))
        it_features = self.activation(it_features)

        features = torch.cat((v1_features, v2_features, v4_features, it_features), dim=1)
        features = self.fc(features)

        return outputs, features

# run everything on CPU
device = 'cpu'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Construct CORnet_S
realnet = CORnet_S()
# (Optional) remove DataParallel if not needed for CPU
# realnet = torch.nn.DataParallel(realnet)

# Build encoder model
encoder = Encoder(realnet, 340)

# Download weights
url = 'https://drive.google.com/uc?id=1AtpS7dPV8t3e1aT8a4Nu-mIFkbWRH-ff'
output_file = "best_model_params.pt"
gdown.download(url, output_file, quiet=False)


# Load weights onto CPU and remove "module." from keys
weights = torch.load(output_file, map_location='cpu')
new_state_dict = {}
for key, val in weights.items():
    # remove "module." (if it exists) from the key
    new_key = key.replace("module.", "")
    new_state_dict[new_key] = val

encoder.load_state_dict(new_state_dict)


# Load weights onto CPU
# weights = torch.load(output_file, map_location='cpu')
# encoder.load_state_dict(weights)

# Retrieve the realnet portion from the encoder
realnet = encoder.realnet
realnet.eval()

def get_model():
    model = realnet
    preprocessing = functools.partial(load_preprocess_images, image_size=224)
    wrapper = PytorchWrapper(identifier='ReAlnet01', model=model, preprocessing=preprocessing)
    wrapper.image_size = 224
    return wrapper
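
Note (not part of the commit): a minimal sketch of using the wrapper returned by get_model() to extract activations, assuming the PytorchWrapper call signature wrapper(stimuli, layers=...) accepts a list of image file paths; the paths below are placeholders.

wrapper = get_model()
image_paths = ['/path/to/image1.png', '/path/to/image2.png']  # placeholder paths
activations = wrapper(image_paths, layers=['IT'])  # activations for the IT block, one row per image
print(activations.shape)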
7 changes: 7 additions & 0 deletions brainscore_vision/models/ReAlnet01/requirements.txt
@@ -0,0 +1,7 @@

torch
torchvision
pandas
numpy
h5py
gdown
15 changes: 15 additions & 0 deletions brainscore_vision/models/ReAlnet01/test.py
@@ -0,0 +1,15 @@
import pytest
from pytest import approx

from brainscore_vision import score


@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier, benchmark_identifier, expected_score", [
    ("ReAlnet01", "MajajHong2015.IT-pls", approx(0.0153, abs=0.0005)),
])
def test_score(model_identifier, benchmark_identifier, expected_score):
    actual_score = score(model_identifier=model_identifier, benchmark_identifier=benchmark_identifier,
                         conda_active=True)
    assert actual_score == expected_score
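
Note (not part of the commit): the benchmark above is marked private_access; for a quick local check one could score against a public variant instead, assuming the benchmark identifier MajajHong2015public.IT-pls is available in this brainscore_vision installation.

from brainscore_vision import score
result = score(model_identifier='ReAlnet01', benchmark_identifier='MajajHong2015public.IT-pls')
print(result)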
