
[BUG] RuntimeError: The size of tensor a (2048) must match the size of tensor b (1024) at non-singleton dimension 2 #6910

Open
Lowlowlowlowlowlow opened this issue Dec 24, 2024 · 1 comment

Describe the bug
This happens during LoRA fine-tuning of Llama 3.1 8B. Does DeepSpeed-Chat support Llama 3.1?

Log output
Input shape: torch.Size([1, 1088, 4096])
Weight shape: torch.Size([4096, 4096])
LoRA Right Weight shape: torch.Size([4096, 32])
LoRA Left Weight shape: torch.Size([32, 4096])
Result shape: torch.Size([1, 1088, 4096])
Input shape: torch.Size([1, 1088, 4096])
Weight shape: torch.Size([2048, 4096])
LoRA Right Weight shape: torch.Size([4096, 32])
LoRA Left Weight shape: torch.Size([32, 1024])
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/lgh/models/trustworthy-alignment/training/main.py", line 768, in
[rank0]: main()
[rank0]: File "/home/lgh/models/trustworthy-alignment/training/main.py", line 572, in main
[rank0]: out = trainer.generate_experience(batch_prompt['prompt'],
[rank0]: File "/home/lgh/models/trustworthy-alignment/dschat/rlhf/ppo_trainer.py", line 232, in generate_experience
[rank0]: output = self.actor_model(seq, attention_mask=attention_mask)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1163, in forward
[rank0]: outputs = self.model(
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 913, in forward
[rank0]: layer_outputs = decoder_layer(
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 640, in forward
[rank0]: hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 399, in forward
[rank0]: key_states = self.k_proj(hidden_states)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/lgh/anaconda3/envs/rltmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/lgh/models/trustworthy-alignment/dschat/utils/module/lora.py", line 123, in forward
[rank0]: result = F.linear(input, self.weight, self.bias) +
[rank0]: RuntimeError: The size of tensor a (2048) must match the size of tensor b (1024) at non-singleton dimension 2
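
For reference, the failing line in dschat/utils/module/lora.py adds the LoRA branch to the frozen base projection. A minimal standalone sketch of that pattern (shapes copied from the log above; random tensors, scaling factor omitted) reproduces the mismatch:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1088, 4096)      # hidden states, as in the log
weight = torch.randn(2048, 4096)    # frozen k_proj weight, out_features=2048
lora_right = torch.randn(4096, 32)  # in_features -> rank
lora_left = torch.randn(32, 1024)   # rank -> 1024 (mismatched)

base = F.linear(x, weight)          # shape [1, 1088, 2048]
lora = x @ lora_right @ lora_left   # shape [1, 1088, 1024]
out = base + lora                   # RuntimeError: 2048 vs 1024 at dim 2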

Lowlowlowlowlowlow added the bug (Something isn't working) and deepspeed-chat (Related to DeepSpeed-Chat) labels on Dec 24, 2024
hwchen2017 (Contributor) commented Jan 6, 2025

Hi @Lowlowlowlowlowlow, from the log you shared,

Weight shape: torch.Size([2048, 4096])
LoRA Right Weight shape: torch.Size([4096, 32])
LoRA Left Weight shape: torch.Size([32, 1024])

the dimensions of these three matrices don't match. The product of the LoRA right and left weights should have the same shape as the original weight matrix (up to a transpose), so that the LoRA update can be added to the base projection's output. Without seeing your implementation, my guess is that your LoRA left weight is supposed to have shape [32, 2048] to match the base weight's output dimension of 2048.
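
For illustration, here is a minimal sketch of a LoRA linear wrapper where the left weight is sized to the base layer's out_features; the class and attribute names are illustrative, not the actual DeepSpeed-Chat lora.py implementation (which also applies a scaling factor and non-zero init on the right weight):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearWithLoRA(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 32):
        super().__init__()
        self.weight = base.weight  # e.g. [2048, 4096] for this k_proj
        self.bias = base.bias
        # The left weight must map rank -> out_features so that the
        # LoRA branch matches the base projection's output shape.
        self.lora_right_weight = nn.Parameter(torch.zeros(base.in_features, rank))  # [4096, 32]
        self.lora_left_weight = nn.Parameter(torch.zeros(rank, base.out_features))  # [32, 2048]

    def forward(self, input):
        return F.linear(input, self.weight, self.bias) + \
               input @ self.lora_right_weight @ self.lora_left_weight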

hwchen2017 self-assigned this on Jan 6, 2025