feature(nvidia): add paged_attention
commit c96e3d7 (parent 157fe96)
Showing 2 changed files with 542 additions and 0 deletions.
@@ -0,0 +1,176 @@
#pragma once

#include <float.h>
#include <type_traits>
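
// NOTE: the primary templates Vec and FloatVec, the Float4_/Float8_ structs,
// and the half <-> float helpers used below (float_to_half, float2_to_half2,
// half_to_float, half2_to_float2, h0_h0, mul, sum) are assumed to be declared
// by attention headers included ahead of this one; this file only provides
// the FP16 specializations.
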
// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};
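// (Vec<uint16_t, N>::Type packs N fp16 values into one 16*N-bit register, so
// a single load, store, or packed instruction covers the whole vector.)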

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
  using Type = float;
};
template <>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template <>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
  using Type = Float8_;
};
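// (Partial products are accumulated in fp32 rather than fp16 so that long
// reductions do not lose precision.)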

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); }

inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// Vector fused multiply-add: d = a * b + c on both packed fp16 lanes.
// fma.rn.f16x2 is the PTX instruction for a two-lane half-precision FMA.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
  return d;
}

// Scalar fp16 a: broadcast it to both lanes with h0_h0, then reuse the packed FMA.
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); }

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// FMA overloads with fp16 inputs and fp32 accumulators.
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); }

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}
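
// (These overloads let callers write fma(a, b, c) uniformly over the packed
// fp16 and fp32 accumulator types above; a scalar uint16_t first operand is
// always broadcast to both lanes via h0_h0.)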

// Q*K^T operation.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction: sum the lanes of the fp32 accumulator vector,
  // then butterfly-reduce (shuffle-xor) across the THREAD_GROUP_SIZE threads
  // that share this query/key pair, so every thread ends up with the full sum.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}
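
// Qk_dot dispatches on the element type and thread-group size; other dtypes
// (or architecture-specific fast paths) can specialize this struct.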
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template <typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};
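
For orientation, here is a minimal usage sketch (hypothetical, not part of this commit) of the Qk_dot entry point: a group of 4 threads cooperatively holds a 32-element fp16 query/key slice as arrays of packed vectors, and each thread gets back the full dot product. mul and sum are assumed to come from the shared attention headers noted above.

// Hypothetical sketch: THREAD_GROUP_SIZE = 4 threads share one query/key pair.
using Q_vec = Vec<uint16_t, 4>::Type;  // uint2: 4 packed fp16 values

__device__ float example_qk(const Q_vec (&q)[2], const Q_vec (&k)[2]) {
  // Each thread holds 2 vectors x 4 halves = 8 elements, so 4 threads cover a
  // 32-element head slice; the shuffle reduction returns the group-wide sum
  // to every participating thread.
  return Qk_dot<uint16_t, 4>::dot(q, k);
}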