diff --git a/brainscore_language/models/gpt/__init__.py b/brainscore_language/models/gpt/__init__.py index f75d9d10..da37298c 100644 --- a/brainscore_language/models/gpt/__init__.py +++ b/brainscore_language/models/gpt/__init__.py @@ -8,6 +8,9 @@ model_registry['distilgpt2'] = lambda: HuggingfaceSubject(model_id='distilgpt2', region_layer_mapping={ ArtificialSubject.RecordingTarget.language_system: 'transformer.h.5'}) +model_registry['gpt2-large'] = lambda: HuggingfaceSubject(model_id='gpt2-large', region_layer_mapping={ + ArtificialSubject.RecordingTarget.language_system: 'transformer.h.33'}) + model_registry['gpt2-xl'] = lambda: HuggingfaceSubject(model_id='gpt2-xl', region_layer_mapping={ ArtificialSubject.RecordingTarget.language_system: 'transformer.h.43'}) diff --git a/brainscore_language/models/gpt/test.py b/brainscore_language/models/gpt/test.py index af9cbbc7..20943f5e 100644 --- a/brainscore_language/models/gpt/test.py +++ b/brainscore_language/models/gpt/test.py @@ -9,6 +9,8 @@ @pytest.mark.parametrize('model_identifier, expected_reading_times', [ ('distilgpt2', [np.nan, 19.260605, 12.721411, 12.083241, 10.876629, 3.678278, 2.102749, 11.961533]), + ('gpt2-large', [np.nan, 13.776375, 5.054959, 0.620946, + 0.522623, 0.102953, 0.038324, 0.021452]), ('gpt2-xl', [np.nan, 1.378484e+01, 6.686095e+00, 2.284407e-01, 7.538393e-01, 6.105860e-03, 2.644155e-02, 4.411311e-03]), ('gpt-neo-2.7B', [np.nan, 15.07522869, 3.6358602 , 0.04999408, 1.42219079, @@ -27,6 +29,7 @@ def test_reading_times(model_identifier, expected_reading_times): @pytest.mark.memory_intense @pytest.mark.parametrize('model_identifier, expected_next_words', [ ('distilgpt2', ['es', 'the', 'fox']), + ('gpt2-large', ['jumps', 'the', 'dog']), ('gpt2-xl', ['jumps', 'the', 'dog']), ('gpt-neo-2.7B', ['jumps', 'the', 'dog']), ('gpt-neo-1.3B', ['jumps', 'the', 'dog']) @@ -42,6 +45,7 @@ def test_next_word(model_identifier, expected_next_words): @pytest.mark.memory_intense @pytest.mark.parametrize('model_identifier, feature_size', [ ('distilgpt2', 768), + ('gpt2-large', 1280), ('gpt2-xl', 1600), ('gpt-neo-1.3B', 2048), ('gpt-neo-2.7B', 2560)