diff --git a/litgpt/model.py b/litgpt/model.py
index 54b34cd478..f3c426192d 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -349,6 +349,7 @@ def forward(
             mask += sliding_window_bias
 
         # Efficient attention using Flash Attention CUDA kernels.
+        # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
         # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
         y = self.scaled_dot_product_attention(q, k, v, mask)
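
For context, here is a minimal standalone sketch (not part of the patch) of why the NOTE matters: PyTorch's `torch.nn.functional.scaled_dot_product_attention` can only dispatch to the Flash Attention kernel when no explicit `attn_mask` is passed, so supplying a dense mask (as the sliding-window path above ends up doing) silently falls back to the slower math or memory-efficient backends. The shapes and device/dtype setup below are illustrative assumptions, not values from the patch.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not from the patch): B=2, nh=4, T=8, hs=16.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # Flash needs fp16/bf16
q = torch.randn(2, 4, 8, 16, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# No explicit mask: SDPA is free to pick the Flash Attention kernel on CUDA.
y_fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Passing any dense `attn_mask` rules out the Flash kernel, so SDPA falls back
# to a slower backend; this is the situation the NOTE in the diff warns about.
causal_mask = torch.ones(8, 8, device=device, dtype=torch.bool).tril()
y_slow = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# Both paths compute the same causal attention; only the kernel choice differs.
torch.testing.assert_close(y_fast, y_slow, atol=1e-3, rtol=1e-3)
```

The same reasoning applies to softcapping: it is not expressible inside the fused Flash kernel's fixed softmax, so enabling it likewise forces the fallback path.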