diff --git a/litgpt/model.py b/litgpt/model.py index 89eb007948..062ae5701a 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -386,8 +386,8 @@ def forward( q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) if self.config.norm_qk: - q = self.norm_q(q) - k = self.norm_k(k) + q = self.q_norm(q) + k = self.k_norm(k) # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the # embedding size (C) into num_heads (nh) and head_size (hs).