Implement GPT-NeoX
tridao committed Mar 29, 2023
1 parent f5d0fbd commit 993d124
Showing 5 changed files with 209 additions and 9 deletions.
9 changes: 6 additions & 3 deletions flash_attn/models/gpt.py
@@ -25,6 +25,7 @@
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox

try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -205,6 +206,8 @@ def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dt
elif model_name.startswith('EleutherAI/gpt-j-'):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
strict = False # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
elif model_name.startswith('EleutherAI/gpt-neox-'):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
@@ -355,6 +358,7 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
lm_head_bias = getattr(config, 'lm_head_bias', False)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
@@ -366,13 +370,12 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
else:
self.project_out = None
if process_group is None:
- self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings,
-                          **factory_kwargs)
+ self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
- embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings,
+ embed_dim, vocab_size, process_group, bias=lm_head_bias,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
107 changes: 107 additions & 0 deletions flash_attn/models/gpt_neox.py
@@ -0,0 +1,107 @@
# Copyright (c) 2023, Tri Dao.

import math
import re

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('embed_out.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)

# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

# Attention
for l in range(config.n_layer):
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias')
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
r'transformer.layers.\1.mixer.rotary_emb.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

return state_dict


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
assert gpt_neox_config.rotary_emb_base == 10000
return GPT2Config(
vocab_size=gpt_neox_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gpt_neox_config.hidden_size,
n_layer=gpt_neox_config.num_hidden_layers,
n_head=gpt_neox_config.num_attention_heads,
n_inner=gpt_neox_config.intermediate_size,
activation_function=gpt_neox_config.hidden_act,
resid_pdrop=0.0, # No dropout
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
initializer_range=gpt_neox_config.initializer_range,
bos_token_id=gpt_neox_config.bos_token_id,
eos_token_id=gpt_neox_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=gpt_neox_config.use_parallel_residual,
parallel_block_tied_norm=False,
rotary_emb_fraction=gpt_neox_config.rotary_pct,
tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
)
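
A note on the Wqkv remapping above: HF GPT-NeoX stores the fused QKV projection with per-head interleaving, i.e. (nheads, 3, headdim) flattened, while this repo groups all query heads, then all key heads, then all value heads, i.e. (3, nheads, headdim) flattened. A minimal sketch of that layout conversion, not part of the commit and using toy sizes chosen only for illustration:

# Sketch: check the head-layout conversion used in remap_state_dict_hf_gpt_neox
import torch
from einops import rearrange

nheads, headdim, hidden = 4, 8, 32                      # toy sizes for illustration
Wqkv_neox = torch.randn(nheads * 3 * headdim, hidden)   # HF layout: (nheads 3 headdim, hidden)

# Same rearrange pattern as in the remap function above
Wqkv_ours = rearrange(Wqkv_neox, '(nheads three headdim) ... -> (three nheads headdim) ...',
                      three=3, headdim=headdim)

# Cross-check against an explicit reshape + permute
expected = (Wqkv_neox.reshape(nheads, 3, headdim, hidden)
            .permute(1, 0, 2, 3).reshape(3 * nheads * headdim, hidden))
assert torch.equal(Wqkv_ours, expected)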
5 changes: 5 additions & 0 deletions flash_attn/models/gptj.py
@@ -34,6 +34,10 @@ def key_mapping_emb(key):
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)

# LayerNorm
def key_mapping_ln(key):
@@ -92,4 +96,5 @@ def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
lm_head_bias=True,
)
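
The padding calls in the remap functions follow F.pad's convention of listing pad widths from the last dimension backwards: for the 2-D embedding weight, (0, 0, 0, k) leaves the embedding dimension alone and appends k zero rows, while the 1-D lm_head bias only needs (0, k). A small sketch with illustrative sizes, not part of the commit:

import torch
import torch.nn.functional as F

vocab_size, padded_vocab_size, embed_dim = 50400, 50432, 16   # illustrative sizes
weight = torch.randn(vocab_size, embed_dim)
bias = torch.randn(vocab_size)

# 2-D: pad pairs are (last-dim left, last-dim right, first-dim top, first-dim bottom)
weight_padded = F.pad(weight, (0, 0, 0, padded_vocab_size - vocab_size))
# 1-D: just (left, right)
bias_padded = F.pad(bias, (0, padded_vocab_size - vocab_size))

assert weight_padded.shape == (padded_vocab_size, embed_dim)
assert bias_padded.shape == (padded_vocab_size,)
assert torch.equal(weight_padded[vocab_size:],
                   torch.zeros(padded_vocab_size - vocab_size, embed_dim))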
84 changes: 84 additions & 0 deletions tests/models/test_gpt_neox.py
@@ -0,0 +1,84 @@
import time

import torch
import pytest

from transformers import GPTNeoXConfig, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
config.fused_dropout_add_ln = False # We don't support parallel block yet
config.residual_in_fp32 = True

model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()

torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model

# Need at least 2 GPUs, otherwise we'll OOM
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref

model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf

print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()

print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
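
For reference, a minimal sketch of the loading path this test exercises: convert the HF GPTNeoXConfig and let GPTLMHeadModel.from_pretrained remap the checkpoint. The model name, device, and optimization flags below mirror the test and are illustrative rather than the only valid choices:

import torch
from transformers import GPTNeoXConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

model_name = "EleutherAI/gpt-neox-20b"
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True        # optional optimizations, as set in the test above
config.fused_bias_fc = True
config.residual_in_fp32 = True

# from_pretrained dispatches to remap_state_dict_hf_gpt_neox for 'EleutherAI/gpt-neox-*' names
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 128), dtype=torch.long, device='cuda')
with torch.no_grad():
    logits = model(input_ids).logits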
13 changes: 7 additions & 6 deletions tests/models/test_gptj.py
@@ -3,7 +3,7 @@
import torch
import pytest

- from transformers import GPTJConfig
+ from transformers import GPTJConfig, AutoTokenizer
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
@@ -37,7 +37,6 @@ def test_gptj_optimized(model_name):
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = False # We don't support parallel block yet
- # Only prenorm supports residual_in_fp32
config.residual_in_fp32 = True

model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
@@ -46,22 +45,24 @@
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
- seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
+ seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
-                          device='cuda')
+                          device=device)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model

- model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
+ # Without device_map, the model is loaded on the CPU, which is very slow
+ model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref

- model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+ model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                            device_map={"": device})
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
