📚 FA2: QK Fine-grained Tiling
What's Changed
- [FA2] hotfix flash-attn-mma smem size setting✔️ by @DefTruth in #170
- [FA2] reorder grid layout, boost 5~10% TFLOPS✔️ by @DefTruth in #171
- [FA2] optimize block tiling for headdim >= 128✔️ by @DefTruth in #172
- [FA2] flash-attn-mma tiling-qk for large d⚡️ by @DefTruth in #173
- [FA2] fix tiling-qk misaligned address✔️ by @DefTruth in #174
- [README] Refactor README.md✔️ by @DefTruth in #175
- [README] Refactor README✔️ by @DefTruth in #176
📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
```cpp
// Fine-grained tiling at the MMA level for Q and K keeps the SRAM usage for
// Q and K constant at 64 * kMmaAtomK, independent of d. For V, the SRAM
// complexity is O(kMmaAtomK * d), so the overall SRAM complexity is
// O(kMmaAtomK * d). This allows D (the head dimension) to be extended up to
// 1024. Performance numbers are still being tuned; stay tuned for updates.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
Full Changelog: v2.6.9...v2.6.10