From 993d12448e2af5fe73bad1c8f93a3cb524aade33 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 29 Mar 2023 01:21:25 -0700
Subject: [PATCH] Implement GPT-NeoX

---
 flash_attn/models/gpt.py      |   9 ++-
 flash_attn/models/gpt_neox.py | 107 ++++++++++++++++++++++++++++++++++
 flash_attn/models/gptj.py     |   5 ++
 tests/models/test_gpt_neox.py |  84 ++++++++++++++++++++++++++
 tests/models/test_gptj.py     |  13 +++--
 5 files changed, 209 insertions(+), 9 deletions(-)
 create mode 100644 flash_attn/models/gpt_neox.py
 create mode 100644 tests/models/test_gpt_neox.py

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 566697b4a..fddaddd92 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -25,6 +25,7 @@
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.models.gptj import remap_state_dict_hf_gptj
+from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -205,6 +206,8 @@ def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dt
         elif model_name.startswith('EleutherAI/gpt-j-'):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
             strict = False  # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
+        elif model_name.startswith('EleutherAI/gpt-neox-'):
+            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         else:
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
@@ -355,6 +358,7 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
         self.process_group = process_group
         self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
         self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
+        lm_head_bias = getattr(config, 'lm_head_bias', False)
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
@@ -366,13 +370,12 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
         else:
             self.project_out = None
         if process_group is None:
-            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings,
-                                     **factory_kwargs)
+            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
             self.lm_head = ColumnParallelLinear(
-                embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings,
+                embed_dim, vocab_size, process_group, bias=lm_head_bias,
                 sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
             )
         # Initialize weights and apply final processing
diff --git a/flash_attn/models/gpt_neox.py b/flash_attn/models/gpt_neox.py
new file mode 100644
index 000000000..9f23387a7
--- /dev/null
+++ b/flash_attn/models/gpt_neox.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2023, Tri Dao.
+
+import math
+import re
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+from einops import rearrange
+
+from transformers import GPT2Config, GPTNeoXConfig
+
+
+def remap_state_dict_hf_gpt_neox(state_dict, config):
+    def key_mapping_layers(key):
+        return re.sub(r'^gpt_neox.', 'transformer.', key)
+    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
+    # Word embedding
+    def key_mapping_emb(key):
+        return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+    if getattr(config, 'tie_word_embeddings'):
+        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+    else:
+        output_embeddings = state_dict.pop('embed_out.weight')
+        # It's possible that vocab_size is padded to be a multiple of 8, for example.
+        state_dict['lm_head.weight'] = F.pad(
+            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
+        )
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
+        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    # MLP
+    def key_mapping_mlp(key):
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # Attention
+    for l in range(config.n_layer):
+        # We don't store these biases
+        state_dict.pop(f'transformer.layers.{l}.attention.bias')
+        state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
+        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
+        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
+        headdim = config.hidden_size // config.num_attention_heads
+        Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
+            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
+            three=3, headdim=headdim
+        )
+        bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
+            bqkv, '(nheads three headdim) -> (three nheads headdim)',
+            three=3, headdim=headdim
+        )
+    def key_mapping_attn(key):
+        key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
+                     r'transformer.layers.\1.mixer.out_proj.', key)
+        key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
+                     r'transformer.layers.\1.mixer.rotary_emb.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+    return state_dict
+
+
+def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
+    assert gpt_neox_config.rotary_emb_base == 10000
+    return GPT2Config(
+        vocab_size=gpt_neox_config.vocab_size,
+        n_positions=0,  # No absolute position embedding
+        n_embd=gpt_neox_config.hidden_size,
+        n_layer=gpt_neox_config.num_hidden_layers,
+        n_head=gpt_neox_config.num_attention_heads,
+        n_inner=gpt_neox_config.intermediate_size,
+        activation_function=gpt_neox_config.hidden_act,
+        resid_pdrop=0.0,  # No dropout
+        embd_pdrop=0.0,
+        attn_pdrop=0.0,
+        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
+        initializer_range=gpt_neox_config.initializer_range,
+        bos_token_id=gpt_neox_config.bos_token_id,
+        eos_token_id=gpt_neox_config.eos_token_id,
+        # These are new arguments not in the original GPT2Config
+        prenorm=True,
+        parallel_block=gpt_neox_config.use_parallel_residual,
+        parallel_block_tied_norm=False,
+        rotary_emb_fraction=gpt_neox_config.rotary_pct,
+        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
+    )
diff --git a/flash_attn/models/gptj.py b/flash_attn/models/gptj.py
index ff859989f..8a6c6f242 100644
--- a/flash_attn/models/gptj.py
+++ b/flash_attn/models/gptj.py
@@ -34,6 +34,10 @@ def key_mapping_emb(key):
     state_dict['lm_head.weight'] = F.pad(
         output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
     )
+    output_embeddings_bias = state_dict.pop('lm_head.bias')
+    state_dict['lm_head.bias'] = F.pad(
+        output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
+    )
 
     # LayerNorm
     def key_mapping_ln(key):
@@ -92,4 +96,5 @@ def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
         tie_word_embeddings=False,
         qkv_proj_bias=False,
         out_proj_bias=False,
+        lm_head_bias=True,
     )
diff --git a/tests/models/test_gpt_neox.py b/tests/models/test_gpt_neox.py
new file mode 100644
index 000000000..2a7f74f86
--- /dev/null
+++ b/tests/models/test_gpt_neox.py
@@ -0,0 +1,84 @@
+import time
+
+import torch
+import pytest
+
+from transformers import GPTNeoXConfig, AutoTokenizer
+from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
+
+from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
+from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
+def test_gpt_neox_state_dict(model_name):
+    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
+    pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config)
+    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
+    state_dict = model.state_dict()
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
+        assert state_dict[k].shape == pretrained_state_dict[k].shape
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
+def test_gpt_neox_optimized(model_name):
+    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
+    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.residual_in_fp32 = True
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device=device)
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+
+    # Need at least 2 GPUs, otherwise we'll OOM
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
+    model_ref.eval()
+    with torch.no_grad():
+        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
+        logits_ref = model_ref(input_ids).logits.to(device=device)
+    del model_ref
+
+    model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                                  device_map={"": device})
+    model_hf.eval()
+    with torch.no_grad():
+        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
+        logits_hf = model_hf(input_ids).logits
+    del model_hf
+
+    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
+    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
+
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
+    assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index 8e9b6df6b..0ef0ee8b7 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -3,7 +3,7 @@
 import torch
 import pytest
 
-from transformers import GPTJConfig
+from transformers import GPTJConfig, AutoTokenizer
 from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -37,7 +37,6 @@ def test_gptj_optimized(model_name):
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = False  # We don't support parallel block yet
-    # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = True
 
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
@@ -46,22 +45,24 @@ def test_gptj_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
-                              device='cuda')
+                              device=device)
     with torch.no_grad():
         out = model.transformer(input_ids)
         logits = model(input_ids).logits
     del model
 
-    model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
     model_ref.eval()
     with torch.no_grad():
         out_ref = model_ref.transformer(input_ids).last_hidden_state
         logits_ref = model_ref(input_ids).logits
     del model_ref
 
-    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
     model_hf.eval()
     out_hf = model_hf.transformer(input_ids).last_hidden_state
     logits_hf = model_hf(input_ids).logits
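
For reference, a minimal usage sketch of the GPT-NeoX path this patch adds, mirroring what tests/models/test_gpt_neox.py exercises. The model name, device, and optimization flags below are illustrative; the fused options assume the corresponding flash-attn CUDA extensions are installed and can be left off.

    import torch
    from transformers import GPTNeoXConfig

    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

    model_name = "EleutherAI/gpt-neox-20b"  # any HF GPT-NeoX checkpoint
    # Translate the HF GPT-NeoX config into the GPT2Config used by the flash-attn GPT model
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True    # optional optimizations, as in the test
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.residual_in_fp32 = True

    # from_pretrained remaps the HF checkpoint via remap_state_dict_hf_gpt_neox
    model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 256), dtype=torch.long, device='cuda')
    with torch.no_grad():
        logits = model(input_ids).logits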