forked from Dao-AILab/flash-attention
Showing 5 changed files with 209 additions and 9 deletions.
flash_attn/models/gpt_neox.py (new file, 107 lines):
# Copyright (c) 2023, Tri Dao.

import math
import re

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r'^gpt_neox.', 'transformer.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('embed_out.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f'transformer.layers.{l}.attention.bias')
        state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim
        )
        bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
            bqkv, '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim
        )
    def key_mapping_attn(key):
        key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
                     r'transformer.layers.\1.mixer.out_proj.', key)
        key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
                     r'transformer.layers.\1.mixer.rotary_emb.', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    assert gpt_neox_config.rotary_emb_base == 10000
    return GPT2Config(
        vocab_size=gpt_neox_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gpt_neox_config.hidden_size,
        n_layer=gpt_neox_config.num_hidden_layers,
        n_head=gpt_neox_config.num_attention_heads,
        n_inner=gpt_neox_config.intermediate_size,
        activation_function=gpt_neox_config.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
        initializer_range=gpt_neox_config.initializer_range,
        bos_token_id=gpt_neox_config.bos_token_id,
        eos_token_id=gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=gpt_neox_config.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=gpt_neox_config.rotary_pct,
        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
    )
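
The attention-weight permutation above is the only non-trivial part of the remapping: HF GPT-NeoX interleaves q, k, v per head, while this repo groups all q heads, then all k, then all v. Below is a small standalone sketch of what the rearrange pattern does; the sizes (2 heads, head dimension 3, hidden dimension 6) are purely illustrative and not tied to any real checkpoint.

# Illustrative check of the Wqkv layout conversion used above.
import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 6
# HF GPT-NeoX layout: rows grouped as (head0: q k v, head1: q k v).
Wqkv_hf = torch.arange(nheads * 3 * headdim * hidden, dtype=torch.float32).reshape(
    nheads * 3 * headdim, hidden
)
# Repo layout: rows grouped as (all q heads, all k heads, all v heads).
Wqkv_repo = rearrange(
    Wqkv_hf, '(nheads three headdim) ... -> (three nheads headdim) ...',
    three=3, headdim=headdim
)
# Rows 0..2 of the result are head 0's q rows; rows 3..5 are head 1's q rows,
# which in the HF layout sit after head 0's q, k, v blocks (rows 9..11).
assert torch.equal(Wqkv_repo[headdim:2 * headdim], Wqkv_hf[3 * headdim:4 * headdim])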
tests/models/test_gpt_neox.py (new file, 84 lines):
import time

import torch
import pytest

from transformers import GPTNeoXConfig, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
    config.fused_dropout_add_ln = False  # We don't support parallel block yet
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Need at least 2 GPUs, otherwise we'll OOM
    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
                                                  device_map={"": device})
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
    assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
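
For reference, the loading recipe the test relies on can be condensed to a few lines. This is only a sketch built from the calls exercised above; the model name, device, and dtype are the ones used in the test, and the state-dict remapping is assumed to happen inside GPTLMHeadModel.from_pretrained, which is how the test loads the HF checkpoint.

# Minimal loading sketch, mirroring the test setup above (not a new API).
import torch
from transformers import GPTNeoXConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

model_name = "EleutherAI/gpt-neox-20b"  # checkpoint used in the tests
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
# Optional speed/memory optimizations, as enabled in test_gpt_neox_optimized.
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.residual_in_fp32 = True

# The HF -> flash_attn weight remapping is expected to be applied inside from_pretrained.
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()
with torch.no_grad():
    input_ids = torch.randint(0, config.vocab_size, (1, 32), device='cuda')
    logits = model(input_ids).logits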