Skip to content

Commit

Permalink
Update pnasnet_large (#1755)
Browse files Browse the repository at this point in the history
* change identifiers, add layer hooks to pnasnet

* Add pnasnet_large.json to region_layer_map for model pnasnet_large

---------

Co-authored-by: KartikP <[email protected]>
  • Loading branch information
samwinebrake and KartikP authored Jan 24, 2025
1 parent 7361ea5 commit ad0dd5d
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['pnasnet_large_pytorch'] = lambda: ModelCommitment(identifier='pnasnet_large_pytorch',
activations_model=get_model('pnasnet_large_pytorch'),
layers=get_layers('pnasnet_large_pytorch'))
model_registry['pnasnet_large'] = lambda: ModelCommitment(identifier='pnasnet_large',
activations_model=get_model('pnasnet_large'),
layers=get_layers('pnasnet_large'))
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,19 @@
MODEL = timm.create_model('pnasnet5large.tf_in1k', pretrained=True)

def get_model(name):
assert name == 'pnasnet_large_pytorch'
preprocessing = functools.partial(load_preprocess_images, image_size=331)
wrapper = PytorchWrapper(identifier='pnasnet_large_pytorch', model=MODEL,
assert name == 'pnasnet_large'
preprocessing = functools.partial(load_preprocess_images, image_size=331, preprocess_type='inception')
wrapper = PytorchWrapper(identifier='pnasnet_large', model=MODEL,
preprocessing=preprocessing,
batch_size=4) # doesn't fit into 12 GB GPU memory otherwise
wrapper.image_size = 331
return wrapper


def get_layers(name):
assert name == 'pnasnet_large_pytorch'
layer_names = []

for name, module in MODEL.named_modules():
layer_names.append(name)

return layer_names[2:]
assert name == 'pnasnet_large'
layer_names = [f'cell_{i + 1}' for i in range(-1, 11)] + ['global_pool']
return layer_names


def get_bibtex(model_identifier):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"V1": "cell_0",
"V2": "cell_4",
"V4": "cell_2",
"IT": "cell_8"
}
8 changes: 8 additions & 0 deletions brainscore_vision/models/pnasnet_large/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
import brainscore_vision


@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('pnasnet_large')
assert model.identifier == 'pnasnet_large'
8 changes: 0 additions & 8 deletions brainscore_vision/models/pnasnet_large_pytorch/test.py

This file was deleted.

0 comments on commit ad0dd5d

Please sign in to comment.