comparing HF vs FA2 llama2 models #475
We have code to convert weights from Meta and HF to be compatible with the implementation in this repo. `input_layernorm` -> `norm1`, `post_attention_layernorm` -> `norm2`.
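For reference, the conversion described above boils down to renaming state-dict keys. A minimal sketch of the idea follows; the helper name is hypothetical and the repo's own conversion script is the authoritative source:

```python
# Hypothetical sketch: remap HF llama state-dict keys to the layernorm naming
# used by the implementation in this repo (norm1/norm2 instead of
# input_layernorm/post_attention_layernorm). Key names are illustrative only.
def remap_hf_llama_keys(state_dict):
    """Rename HF-style layernorm keys to norm1/norm2-style keys."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key.replace("input_layernorm", "norm1")
        new_key = new_key.replace("post_attention_layernorm", "norm2")
        remapped[new_key] = value
    return remapped
```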
thanks @tridao, and sorry for the dumb questions. i see now that the order of the nodes in the print output is not relevant, as the code dictates how they're wired together.

i realized that some of the optimizations are standalone, and can be integrated into the transformers llama model directly. for example, xentropy and rmsnorm: huggingface/transformers@main...tmm1:transformers:llama-flash

however, other optimizations such as rotary_emb require more structural changes, and are simplest to use with the model implementation in this repo.

does this make any sense to you, or do you have ideas for where i can go next to investigate?
My guess is that it's because our
FWIW I have this model definition integrated into our training code based on the huggingface trainer. I can confirm that you need to override

Also, for sequence parallel, you may want to apply
could you share what that override looks like?
I've been comparing
Aha I just saw this:
I'm still not able to measure any difference. I'm using the HF trainer and model with this change:

```python
import transformers
from functools import partial
from flash_attn.losses.cross_entropy import CrossEntropyLoss

transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)
```
It all depends on how much time / memory the cross entropy is taking. You can benchmark to see how much CE is taking. If it's taking like 1-2% of the total time then it doesn't matter. For small models and/or large vocab sizes you'll see a speedup and memory savings.
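For instance, a rough way to benchmark the two losses in isolation is sketched below. The sizes are illustrative, and the fused kernel is assumed to need a CUDA device:

```python
# Timing sketch: torch.nn.CrossEntropyLoss vs. the fused flash-attn kernel on
# llama-ish logits. Numbers (batch, seqlen, vocab) are illustrative only.
import torch
import torch.utils.benchmark as benchmark
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss

batch, seqlen, vocab = 1, 2048, 32000
logits = torch.randn(batch * seqlen, vocab, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, vocab, (batch * seqlen,), device="cuda")

def run(loss_fn):
    # Fresh leaf tensor each call so in-place backward doesn't affect later runs
    x = logits.detach().clone().requires_grad_(True)
    loss = loss_fn(x, labels)
    loss.backward()

for name, loss_fn in [("torch CE", torch.nn.CrossEntropyLoss()),
                      ("fused CE", FusedCrossEntropyLoss(inplace_backward=True))]:
    t = benchmark.Timer(stmt="run(loss_fn)", globals={"run": run, "loss_fn": loss_fn})
    print(name, t.timeit(20))
```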
I discovered my patching code wasn't running for some silly reason. I'm interested in quantifying the differences between these implementations, especially when it comes to VRAM usage. I used the VRAM line profiler I'm working on here, and measured the following difference:

```diff
- Train VRAM peak: 1.94464 GB
+ Train VRAM peak: 1.78449 GB

Line #    Hits      Mem   Per Hit   % Mem  Line Contents
==============================================================
-  852       2    262.1     131.1    15.1  loss = loss_fct(shift_logits, shift_labels)
+  852       2                             loss = loss_fct(shift_logits, shift_labels)
...
- 2693       2    318.8     159.4    15.5  self.accelerator.backward(loss)
+ 2693       2    408.9     204.5    21.7  self.accelerator.backward(loss)
```

This is a very limited test (2 steps w/ llama2-7b @ 2048 ctx), but it shows the slight benefit of using the custom xentropy kernel.
Is this with batch size 1?
Yea, that was with batchsize=1. I made some more measurements @ ctx=4096:
Yup, seems to check out with my calculation.
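For reference, the per-hit ~131 MB in the profile above lines up with a back-of-the-envelope estimate of the logits size, assuming bf16 activations and llama2's 32000-token vocabulary:

```python
# Back-of-the-envelope logits memory at ctx=2048, batch=1, bf16 (2 bytes/elem),
# vocab=32000. This roughly matches the "Per Hit" column in the profile above.
batch, seqlen, vocab, bytes_per_elem = 1, 2048, 32000, 2
logits_mb = batch * seqlen * vocab * bytes_per_elem / 1e6
print(f"logits: {logits_mb:.1f} MB")  # ~131.1 MB
```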
Now I applied the rmsnorm kernel on top, as follows:

```python
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
    """Patched LlamaRMSNorm"""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps=eps)

transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
```

and the results are further improved for ctx=4096:
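As an aside, a quick way to sanity-check that flash-attn's RMSNorm is a numerical drop-in for HF's LlamaRMSNorm is to compare outputs on random input. This is only a sketch: it assumes a CUDA device, and it should run before the monkeypatch above so the import still picks up the original HF class; small numerical differences between the fused and eager implementations are expected:

```python
# Sanity-check sketch: compare HF's eager LlamaRMSNorm against the fused
# flash-attn RMSNorm on random input (run before applying the monkeypatch).
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm as HFLlamaRMSNorm
from flash_attn.ops.rms_norm import RMSNorm

hidden = 4096
x = torch.randn(2, 16, hidden, device="cuda", dtype=torch.float16)

hf_norm = HFLlamaRMSNorm(hidden, eps=1e-6).to("cuda", torch.float16)
fa_norm = RMSNorm(hidden, eps=1e-6).to("cuda", torch.float16)
fa_norm.weight.data.copy_(hf_norm.weight.data)  # use the same scale parameters

print((hf_norm(x) - fa_norm(x)).abs().max())  # expect a small value in fp16
```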
I'm working on the rotary kernel next, but am not quite sure if I'm handling the cos/sin correctly:

```diff
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, MSELoss
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 from flash_attn.ops.rms_norm import RMSNorm
+from flash_attn.layers.rotary import apply_rotary_emb
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -324,7 +325,14 @@ class LlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos = cos.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+        sin = sin.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+        # cos, sin: (seqlen_rotary, rotary_dim / 2)
+        cos = cos[..., :cos.shape[-1] * 2:2]
+        sin = sin[..., :sin.shape[-1] * 2:2]
+        # fused rope expects (batch_size, seqlen, nheads, headdim)
+        query_states = apply_rotary_emb(query_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)
+        key_states = apply_rotary_emb(key_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)

         if past_key_value is not None:
             # reuse k, v, self_attention
```
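For what it's worth, at this transformers version the HF rotary cache is built as `torch.cat((freqs, freqs), dim=-1)` with shape `(1, 1, seqlen, head_dim)`, while the comment in the diff says `apply_rotary_emb` wants `(seqlen_rotary, rotary_dim / 2)`. One way to reduce the HF cache to that layout is sketched below; this is not verified against the repo's own MHA code, and assumes the non-interleaved (rotate-half) convention:

```python
# Sketch: reducing HF's cos/sin cache to a (seqlen, rotary_dim / 2) layout.
# HF builds the cache as torch.cat((freqs, freqs), dim=-1), so the two halves
# are identical and the first half carries all the information.
import torch

head_dim, seqlen, base = 128, 4096, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seqlen).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)       # (seqlen, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)            # HF-style cache
cos_hf = emb.cos()[None, None, :, :]               # (1, 1, seqlen, head_dim)

cos_reduced = cos_hf[0, 0, :, : head_dim // 2]     # (seqlen, head_dim / 2)
# Note: slicing every other element (cos[..., ::2]) would pick a different set
# of frequencies than the first half, since the cache is two stacked copies.
assert torch.allclose(cos_reduced, freqs.cos())
```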
Why not just convert the HF weights to use the Llama implementation in this repo?
You can also see how we use rotary in MHA here.
Thanks for the pointer. I know I can convert the weights and use the trainer here, but I'm interested in features that transformers offers out of the box, such as loading weights in 4bit/8bit and using PEFT techniques such as LoRA, QLoRA and IA3. I'm also just trying to understand how this stuff works better, so translating between the two implementations is helpful as a learning exercise. I'll read through the MHA implementation and see what I can figure out.
Okay, I see I should probably be using the RotaryEmbedding module here.

On the transformers side there are several variations:

```python
class LlamaRotaryEmbedding(torch.nn.Module):
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
```

I'll see what it takes to get these variations into the fa2 rotary module first.
Yeah, we should probably add those as arguments to the RotaryEmbedding module (scaling factor, scaling_method="ntk" or scaling_method="standard"). Is "standard" the right name or is there another name?
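For context, the two transformers variants only differ in how the frequency table is built. A sketch of the math, following the transformers implementations of that era (names and defaults here are illustrative):

```python
# Sketch of the two rope-scaling variants: "linear" scaling stretches the
# position indices, while "dynamic NTK" scaling enlarges the rotary base once
# the sequence exceeds the original max_position_embeddings.
import torch

def rope_freqs(seqlen, dim, base=10000.0, max_pos=4096,
               scaling_factor=1.0, scaling_type=None):
    if scaling_type == "dynamic" and seqlen > max_pos:
        # NTK-aware: grow the base so the lowest frequency covers the longer context
        base = base * ((scaling_factor * seqlen / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seqlen).float()
    if scaling_type == "linear":
        t = t / scaling_factor  # positions are compressed by the scaling factor
    return torch.outer(t, inv_freq)  # (seqlen, dim / 2)
```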
Have you been able to look at this? I was also wondering how one would use the fused MLP layers with huggingface.
You can use this subclass with the HF trainer: #486
Can it be used in 4bit and with peft like HF models?
Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.
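To make the constraint concrete: peft's LoRA wraps named `nn.Linear` submodules, so it can target the attention projections, but a fused MLP collapses its projections into a fused op with no standalone Linear for peft to wrap. A sketch follows; the module names ("Wqkv", "out_proj") are assumptions about the MHA layer in this repo, and fused MLP is assumed to be left disabled:

```python
# Sketch: applying LoRA via peft to the flash-attn llama model.
# target_modules names are assumptions; fused MLP must be disabled so plain
# nn.Linear layers exist for peft to wrap.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["Wqkv", "out_proj"],  # assumed linear-layer names
    task_type="CAUSAL_LM",
)
# model = ...  # the flash-attn llama model, loaded without fused MLP
# model = get_peft_model(model, lora_config)
```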
Thanks a lot. Can you give a pointer on how to load that model in 4bit using huggingface?
I don't think we can use that model in 4bit.
Is there any information about training speed comparing HF vs FA2 llama2 models?
hi, i'm looking over the optimizations in the trainer here, and trying to port them to the `transformers.trainer.Trainer` for use with llama2.

i put together this simple script to view the differences between the two:

and I see:

i'm trying to understand the differences here.

- it appears there's an extra `RMSNorm` inserted between the attn and mlp? was this intentional?
- it also looks like `GatedMlp` only has two linear layers, but the first is double sized. what's going on there?
- i saw there's no fused version of `GatedMlp` yet. is there any reason to use it over `LlamaMLP` if i'm not doing tensor parallel?
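On the double-sized first linear layer: gated MLPs are commonly implemented by fusing the gate and up projections into one matmul and splitting the result, so a single fc1 of width 2 * intermediate_size replaces HF llama's separate gate_proj/up_proj. A sketch of the pattern is below; it is the general idea, not the repo's exact code, and the chunk order is an assumption:

```python
# Sketch of the fused gated-MLP (SwiGLU) pattern: one double-width fc1 replaces
# separate gate_proj/up_proj, followed by a single fc2 (down_proj).
import torch.nn as nn
import torch.nn.functional as F

class GatedMlpSketch(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        y, gate = self.fc1(x).chunk(2, dim=-1)   # split the double-width output
        return self.fc2(y * F.silu(gate))        # SwiGLU: silu(gate) * up, then down-proj
```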