comparing HF vs FA2 llama2 models #475

Open
tmm1 opened this issue Aug 22, 2023 · 26 comments

tmm1 commented Aug 22, 2023

hi, i'm looking over the optimizations in the trainer here, and trying to port them to the transformers.trainer.Trainer for use with llama2

i put together this simple script to view the differences between the two:

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import llama_config_to_gpt2_config
from transformers import LlamaConfig, LlamaForCausalLM

MODEL = "meta-llama/Llama-2-7b-chat-hf"

config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(MODEL))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False  # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
print(config)

model = GPTLMHeadModel(config, device="meta")
print(model)

model = LlamaForCausalLM.from_pretrained(MODEL, device_map="meta")
print(model)

and I see:

GPTLMHeadModel(
  (transformer): GPTModel(
    (embeddings): GPT2Embeddings(
      (word_embeddings): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x Block(
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): FusedDense(in_features=4096, out_features=12288, bias=False)
          (inner_attn): FlashSelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): FlashCrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (out_proj): FusedDense(in_features=4096, out_features=4096, bias=False)
        )
        (dropout1): Dropout(p=0.0, inplace=False)
        (drop_path1): StochasticDepth(p=0.0, mode=row)
        (norm1): RMSNorm()
        (mlp): GatedMlp(
          (fc1): Linear(in_features=4096, out_features=22016, bias=False)
          (fc2): Linear(in_features=11008, out_features=4096, bias=False)
        )
        (dropout2): Dropout(p=0.0, inplace=False)
        (drop_path2): StochasticDepth(p=0.0, mode=row)
        (norm2): RMSNorm()
      )
    )
    (drop_f): Dropout(p=0.0, inplace=False)
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
LlamaForCausalLM(                                                                                                          
  (model): LlamaModel(                                                                                                     
    (embed_tokens): Embedding(32000, 4096)                                                                                 
    (layers): ModuleList(                                                                                                  
      (0-31): 32 x LlamaDecoderLayer(                                                                                      
        (self_attn): LlamaAttention(                                                                                       
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                
          (rotary_emb): LlamaRotaryEmbedding()                                                                             
        )                                                                                                                  
        (mlp): LlamaMLP(                                                                                                   
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)                                            
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)                                              
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)                                            
          (act_fn): SiLUActivation()                                                                                       
        )                                                                                                                  
        (input_layernorm): LlamaRMSNorm()                                                                                  
        (post_attention_layernorm): LlamaRMSNorm()                                                                         
      )                                                                                                                    
    )                                                                                                                      
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

i'm trying to understand the differences here.

it appears there's an extra RMSNorm inserted between the attn and mlp? was this intentional?

it also looks like GatedMlp only has two linear layers, but the first is double sized. what's going on there?

i saw there's no fused version of GatedMlp yet. is there any reason to use it over LLamaMLP if i'm not doing tensor parallel?


tridao commented Aug 22, 2023

We have code to convert weights from Meta and HF to be compatible with the implementation in this repo.
The test here verifies that the models implemented in this repo match the HF implementation.

input_layernorm -> norm1, post_attention_layernorm -> norm2
mlp.gate_proj and mlp.up_proj are combined into 1 matrix.
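(For illustration, a minimal sketch of why one double-width fc1 can stand in for HF's separate gate_proj/up_proj; the stacking order shown here is arbitrary and just has to match how the gated MLP splits its output, which the repo's weight conversion takes care of.)

import torch
import torch.nn.functional as F

hidden, inter = 4096, 11008
x = torch.randn(2, hidden)

gate_proj = torch.nn.Linear(hidden, inter, bias=False)
up_proj = torch.nn.Linear(hidden, inter, bias=False)

# One fused projection whose weight is the two HF matrices stacked on top of each other.
fc1 = torch.nn.Linear(hidden, 2 * inter, bias=False)
with torch.no_grad():
    fc1.weight.copy_(torch.cat([up_proj.weight, gate_proj.weight], dim=0))

y_hf = F.silu(gate_proj(x)) * up_proj(x)         # LlamaMLP (without down_proj)
up, gate = fc1(x).chunk(2, dim=-1)               # split the double-width output
y_fused = up * F.silu(gate)
print(torch.allclose(y_hf, y_fused, atol=1e-5))  # True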


tmm1 commented Aug 23, 2023

thanks @tridao, and sorry for the dumb questions.

i see now that the order of the nodes in the print output is not relevant, as the code dictates how they're wired together.

my goal is to use huggingface/transformers.trainer.Trainer on llama2, but using the optimizations found here.

i realized that some of the optimizations are standalone, and can be integrated into the transformers llama model directly. for example, xentropy and rmsnorm: huggingface/transformers@main...tmm1:transformers:llama-flash

however other optimizations such as rotary_emb require more structural changes and are simplest to use with Block and MHA directly.

so i've tried to take the GPTLMHeadModel and feed it into transformers.Trainer directly (using #479). there are a few other minor incompatibilities (missing model.device, unimplemented model.gradient_checkpointing_enable(), etc). but after working through those, i end up with this confusing exception:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../transformers/src/transformers/trainer.py:1555: in train
    return inner_training_loop(
../transformers/src/transformers/trainer.py:1837: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
../transformers/src/transformers/trainer.py:2693: in training_step
    self.accelerator.backward(loss)
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/accelerate/accelerator.py:1902: in backward
    loss.backward(**kwargs)
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/torch/_tensor.py:487: in backward
    torch.autograd.backward(
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/torch/autograd/__init__.py:193: in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

outputs = (tensor([[[ 0.2080,  0.0364,  0.2754,  ...,  1.4062,  1.9844,  0.7070],
         [-8.0000, -3.2188, -2.5000,  ..., -5....      [-8.7500, -9.1875,  3.5938,  ..., -6.8125, -6.8438, -5.2188]]],
       device='cuda:0', grad_fn=<DivBackward0>),)
grads = (None,), is_grads_batched = False

    def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
                    is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
        new_grads: List[_OptionalTensor] = []
        for out, grad in zip(outputs, grads):
            if isinstance(grad, torch.Tensor):
                first_grad = grad if not is_grads_batched else grad[0]
                if not torch.is_same_size(out, first_grad):
                    out_shape, grad_shape = _calculate_shape(out, first_grad, is_grads_batched)
                    if is_grads_batched:
                        raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                           "dimension of each grad_output as the batch dimension. "
                                           "The sizes of the remaining dimensions are expected to match "
                                           "the shape of corresponding output, but a mismatch "
                                           "was detected: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad_shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out_shape) + ". "
                                           "If you only want some tensors in `grad_output` to be considered "
                                           "batched, consider using vmap.")
                    else:
                        raise RuntimeError("Mismatch in shape: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad_shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out_shape) + ".")
                if out.dtype.is_complex != grad.dtype.is_complex:
                    raise RuntimeError("For complex Tensors, both grad_output and output"
                                       " are required to have the same dtype."
                                       " Mismatch in dtype: grad_output["
                                       + str(grads.index(grad)) + "] has a dtype of "
                                       + str(grad.dtype) + " and output["
                                       + str(outputs.index(out)) + "] has a dtype of "
                                       + str(out.dtype) + ".")
                new_grads.append(grad)
            elif grad is None:
                if out.requires_grad:
                    if out.numel() != 1:
>                       raise RuntimeError("grad can be implicitly created only for scalar outputs")
E                       RuntimeError: grad can be implicitly created only for scalar outputs

does this make any sense to you, or do you have ideas for where i can go next to investigate?


tridao commented Aug 23, 2023

My guess is that it's because our GPTLMHeadModel doesn't return a loss, it returns the output which is of size (batch, seqlen, vocab_size). You'd need to have a separate loss function (e.g. CrossEntropy).
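(A minimal sketch of what that might look like; the function name is illustrative, and inplace_backward is the option mentioned further down in this thread.)

from flash_attn.losses.cross_entropy import CrossEntropyLoss

def causal_lm_loss(logits, labels):
    # logits: (batch, seqlen, vocab_size) from GPTLMHeadModel; labels: (batch, seqlen).
    # Shift so that tokens < n predict token n, as HF's LlamaForCausalLM does internally.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(inplace_backward=True)
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))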


lxuechen commented Aug 24, 2023

FWIW I have this model def'n integrated into our training code based on huggingface trainer. I can confirm that you need to override compute_loss (ideally using the fused cross entropy loss implemented in this codebase).

Also for sequence parallel, you may want to apply allreduce_sequence_parallel_grad after backward.

@winglian

> FWIW I have this model def'n integrated into our training code based on huggingface trainer. I can confirm that you need to override compute_loss (ideally using the fused cross entropy loss implemented in this codebase).
>
> Also for sequence parallel, you may want to apply allreduce_sequence_parallel_grad after backward.

could you share what the override for compute_loss would look like to use this?
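(A rough sketch of the kind of override meant here, assuming the stock HF Trainer API; the class name and the way the model output is unpacked are assumptions, not code from this repo.)

from transformers import Trainer
from flash_attn.losses.cross_entropy import CrossEntropyLoss

class FlashCrossEntropyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(input_ids=inputs["input_ids"])
        logits = outputs.logits  # (batch, seqlen, vocab_size)
        # Shift so that tokens < n predict token n, then apply the fused loss.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(inplace_backward=True)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))
        return (loss, outputs) if return_outputs else loss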


tmm1 commented Aug 30, 2023

I've been comparing torch.nn.CrossEntropyLoss vs flash_attn.losses.cross_entropy.CrossEntropyLoss, but am not able to measure any memory or speed difference between the two.


tmm1 commented Aug 30, 2023

Aha I just saw this:

partial(CrossEntropyLoss, inplace_backward=True)


tmm1 commented Aug 30, 2023

I'm still not able to measure any difference. I'm using the HF trainer and model with this change:

import transformers
from functools import partial
from flash_attn.losses.cross_entropy import CrossEntropyLoss
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)


tridao commented Aug 30, 2023

> I'm still not able to measure any difference. I'm using the HF trainer and model with this change:
>
> import transformers
> from functools import partial
> from flash_attn.losses.cross_entropy import CrossEntropyLoss
> transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)

It all depends on how much time / memory the cross entropy is taking. You can benchmark to see how much CE is taking. If it's taking like 1-2% of the time then it doesn't matter. For small models and/or large vocab size you'll see speedup and memory saving.
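(A rough standalone way to check that; the sizes and dtype below are illustrative.)

import torch
from torch.nn import CrossEntropyLoss as TorchCE
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCE

def ce_peak_mem(loss_cls, batch=1, seqlen=2048, vocab=32000):
    # Measure peak CUDA memory for one forward + backward of the loss alone.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    logits = torch.randn(batch * seqlen, vocab, device="cuda",
                         dtype=torch.bfloat16, requires_grad=True)
    labels = torch.randint(0, vocab, (batch * seqlen,), device="cuda")
    loss_cls()(logits, labels).backward()
    return torch.cuda.max_memory_allocated() / 2**20

print(f"torch CE: {ce_peak_mem(TorchCE):.0f} MiB peak")
print(f"flash CE: {ce_peak_mem(FlashCE):.0f} MiB peak")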


tmm1 commented Sep 4, 2023

I discovered my patching code wasn't running for some silly reason.

I'm interested in quantifying the differences between these implementations, especially when it comes to VRAM usage.

I used the VRAM line profiler I'm working on here, and measured the following difference:

- Train VRAM peak: 1.94464 GB
+ Train VRAM peak: 1.78449 GB

 Line #      Hits          Mem  Per Hit    % Mem  Line Contents                                                             
============================================================== 
-   852         2        262.1    131.1     15.1              loss = loss_fct(shift_logits, shift_labels)
+   852         2                                             loss = loss_fct(shift_logits, shift_labels)
...
-  2693         2        318.8    159.4     15.5              self.accelerator.backward(loss)
+  2693         2        408.9    204.5     21.7              self.accelerator.backward(loss)

This is a very limited test (2 steps w/ llama2-7b @ 2048 ctx), but it shows the slight benefit present when using the custom xentropy kernel.


tridao commented Sep 4, 2023

Is this with batch size 1?
My back-of-the-envelope calculation: the logits have size (batch, seqlen, vocab_size), taking 2 bytes each (e.g. training with bf16).
Our xentropy kernel avoids storing an extra copy so we save (2 * batch * seqlen * vocab_size) bytes.
With llama 7b and batch=1, seqlen=2048, vocab_size = 30k, this is 123MB.
With larger batch size the memory saving is larger (but maybe you can't run with larger batch size because of GPU mem limit).
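(The same estimate in code:)

# Extra logits copy saved by the fused xentropy kernel, in bf16 (2 bytes per element).
batch, seqlen, vocab_size = 1, 2048, 30_000
saved_bytes = 2 * batch * seqlen * vocab_size
print(f"{saved_bytes / 1e6:.0f} MB")  # ~123 MB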


tmm1 commented Sep 4, 2023

Yea, that was with batch size 1.

I made some more measurements @ ctx=4096:

| cfg | mem |
| --- | --- |
| bs=1 xentropy=false | 3.59699 GB |
| bs=1 xentropy=true | 3.28644 GB |
| bs=2 xentropy=false | 7.27851 GB |
| bs=2 xentropy=true | 6.65741 GB |


tridao commented Sep 4, 2023

Yup seems to check out with my calculation.


tmm1 commented Sep 4, 2023

Now I applied the rmsnorm kernel on top, as follows:

from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):  
    """Patched LLamaRMSNorm"""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps=eps)

transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm

and the results are further improved for ctx=4096:

| bs | xentropy | rms | mem |
| --- | --- | --- | --- |
| 1 | true | false | 3.28644 GB |
| 1 | true | true | 3.20246 GB |
| 2 | true | false | 6.65741 GB |
| 2 | true | true | 5.95893 GB |
| 4 | false | false | 13.9354 GB |
| 4 | true | false | 12.0428 GB |
| 4 | true | true | 11.8348 GB |
| 6 | false | false | 18.3586 GB |
| 6 | true | true | 17.6281 GB |


tmm1 commented Sep 9, 2023

I'm working on the rotary kernel next, but am not quite sure if I'm handling the cos/sin correctly:

--- a/src/transformers/models/llama/modeling_llama.py                                                               
+++ b/src/transformers/models/llama/modeling_llama.py                                                               
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, MSELoss                                                   
                                                                                                                    
 from flash_attn.losses.cross_entropy import CrossEntropyLoss                                                       
 from flash_attn.ops.rms_norm import RMSNorm                                                                        
+from flash_attn.layers.rotary import apply_rotary_emb
         
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -324,7 +325,14 @@ class LlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos = cos.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+        sin = sin.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+        # cos, sin: (seqlen_rotary, rotary_dim / 2)
+        cos = cos[..., :cos.shape[-1] * 2:2]
+        sin = sin[..., :sin.shape[-1] * 2:2]
+        # fused rope expects (batch_size, seqlen, nheads, headdim)
+        query_states = apply_rotary_emb(query_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)
+        key_states = apply_rotary_emb(key_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)
  
         if past_key_value is not None:
             # reuse k, v, self_attention


tridao commented Sep 9, 2023

Why not just convert the HF weights to use the Llama implementation in this repo?

def test_llama_optimized(model_name, checkpoint_format):

You can also see how we use rotary in MHA here.


tmm1 commented Sep 9, 2023

Thanks for the pointer.

I know I can convert the weights and use the trainer here, but I'm interested in features that transformers offers out of the box, such as loading weights in 4bit/8bit and using PEFT techniques such as LoRA, QLoRA and IA3. I'm also just trying to understand how this stuff works better, so translating between the two implementations is helpful as a learning exercise.

I'll read through the MHA implementation and see what I can figure out.


tmm1 commented Sep 9, 2023

Okay I see I should probably be using the flash_attn.layers.rotary.RotaryEmbedding module instead of trying to call apply_rotary_emb directly.

On the transformers side there are several variations:

class LlamaRotaryEmbedding(torch.nn.Module):
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):

I'll see what it takes to get these variations into the fa2 rotary module first


tridao commented Sep 9, 2023

Yeah we should probably add those as arguments to the RotaryEmbedding module (scaling factor, scaling_method="ntk" or scaling_method="standard"). Is "standard" the right name or is there another name?
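(For reference, a sketch of how the two HF variants alter the vanilla rotary cos/sin, mirroring what transformers' LlamaLinearScalingRotaryEmbedding and LlamaDynamicNTKScalingRotaryEmbedding do; the function and argument names here are made up.)

import torch

def rotary_cos_sin(rotary_dim, seqlen, base=10000.0, max_pos=4096,
                   scaling=None, factor=1.0):
    if scaling == "dynamic-ntk" and seqlen > max_pos:
        # Dynamic NTK: enlarge the base once the sequence exceeds the trained length.
        base = base * ((factor * seqlen / max_pos) - (factor - 1)) ** (rotary_dim / (rotary_dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seqlen, dtype=torch.float32)
    if scaling == "linear":
        # Linear scaling: stretch the position indices by the scaling factor.
        t = t / factor
    freqs = torch.outer(t, inv_freq)  # (seqlen, rotary_dim / 2)
    return freqs.cos(), freqs.sin()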

@MrigankRaman

> Thanks for the pointer.
>
> I know I can convert the weights and use the trainer here, but I'm interested in features that transformers offers out of the box, such as loading weights in 4bit/8bit and using PEFT techniques such as LoRA, QLoRA and IA3. I'm also just trying to understand how this stuff works better, so translating between the two implementations is helpful as a learning exercise.
>
> I'll read through the MHA implementation and see what I can figure out.

Have you been able to look at this? I was also wondering how one would use the fused MLP layers with huggingface.


tmm1 commented Sep 24, 2023

You can use this subclass with the HF trainer: #486

@MrigankRaman

> You can use this subclass with the HF trainer: #486

Can it be used with 4-bit and PEFT like HF models?


tmm1 commented Sep 24, 2023

Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.

@MrigankRaman

> Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.

Thanks a lot. Can you give a pointer on how to load that model in 4-bit using huggingface?

@MrigankRaman

> Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.

I don't think we can use that model in 4-bit.


980202006 commented Jul 16, 2024

Is there any information about training speed comparing HF vs FA2 llama2 models?
