Skip to content

Commit

Permalink
Added gpt2-medium model and unit tests (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
shehadak authored Nov 17, 2023
1 parent 3d873e0 commit 8b5e95a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions brainscore_language/models/gpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
model_registry['distilgpt2'] = lambda: HuggingfaceSubject(model_id='distilgpt2', region_layer_mapping={
ArtificialSubject.RecordingTarget.language_system: 'transformer.h.5'})

model_registry['gpt2-medium'] = lambda: HuggingfaceSubject(model_id='gpt2-medium', region_layer_mapping={
ArtificialSubject.RecordingTarget.language_system: 'transformer.h.22'})

model_registry['gpt2-large'] = lambda: HuggingfaceSubject(model_id='gpt2-large', region_layer_mapping={
ArtificialSubject.RecordingTarget.language_system: 'transformer.h.33'})

Expand Down
4 changes: 4 additions & 0 deletions brainscore_language/models/gpt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
@pytest.mark.parametrize('model_identifier, expected_reading_times', [
('distilgpt2', [np.nan, 19.260605, 12.721411, 12.083241,
10.876629, 3.678278, 2.102749, 11.961533]),
('gpt2-medium', [np.nan, 14.88489, 6.539810, 0.08106061,
0.6542016, 0.06957269, 4.027023e-03, 4.039307e-04]),
('gpt2-large', [np.nan, 13.776375, 5.054959, 0.620946,
0.522623, 0.102953, 0.038324, 0.021452]),
('gpt2-xl', [np.nan, 1.378484e+01, 6.686095e+00, 2.284407e-01,
Expand All @@ -29,6 +31,7 @@ def test_reading_times(model_identifier, expected_reading_times):
@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_next_words', [
('distilgpt2', ['es', 'the', 'fox']),
('gpt2-medium', ['jumps', 'the', 'dog']),
('gpt2-large', ['jumps', 'the', 'dog']),
('gpt2-xl', ['jumps', 'the', 'dog']),
('gpt-neo-2.7B', ['jumps', 'the', 'dog']),
Expand All @@ -45,6 +48,7 @@ def test_next_word(model_identifier, expected_next_words):
@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, feature_size', [
('distilgpt2', 768),
('gpt2-medium', 1024),
('gpt2-large', 1280),
('gpt2-xl', 1600),
('gpt-neo-1.3B', 2048),
Expand Down

0 comments on commit 8b5e95a

Please sign in to comment.