feature(nvidia): add paged_attention
kilinchange committed Apr 11, 2024
1 parent 157fe96 commit c96e3d7
Showing 2 changed files with 542 additions and 0 deletions.
176 changes: 176 additions & 0 deletions nvidia/common/src/paged_attention.cuh
@@ -0,0 +1,176 @@
#pragma once

#include <cstdint>
#include <float.h>
#include <type_traits>

// Primary templates for the vector-type traits specialized below. The commit
// does not declare them in this file, so they presumably live in a shared
// attention header elsewhere in the tree; minimal assumed declarations:
template <typename T, int VEC_SIZE>
struct Vec {};
template <typename T>
struct FloatVec {};

// FP32 accumulator aggregates used by the FloatVec specializations below
// (assumed layout, following the common FasterTransformer/vLLM convention
// of grouping float2 halves):
struct Float4_ {
    float2 x, y;
};
struct Float8_ {
    float2 x, y, z, w;
};

// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
    using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
    using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
    using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
    using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
    using Type = float;
};
template <>
struct FloatVec<uint32_t> {
    using Type = float2;
};
template <>
struct FloatVec<uint2> {
    using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
    using Type = Float8_;
};
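
// NOTE: the scalar FP16 helpers used throughout this header (float_to_half,
// float2_to_half2, half_to_float, half2_to_float2, h0_h0) and the FP32
// element-wise fma are not defined in this file; they presumably come from a
// shared utility header in the tree. What follows is a minimal sketch of the
// assumed definitions, modeled on the common FasterTransformer-style PTX
// helpers, not the commit's actual code.
inline __device__ uint16_t float_to_half(float f) {
    uint16_t h;
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h) : "f"(f));
    return h;
}

inline __device__ uint32_t float2_to_half2(float2 f) {
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u16[0] = float_to_half(f.x);
    tmp.u16[1] = float_to_half(f.y);
    return tmp.u32;
}

inline __device__ float half_to_float(uint16_t h) {
    float f;
    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
    return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
    return make_float2(half_to_float(lo), half_to_float(hi));
}

// Broadcast one half into both 16-bit lanes of a 32-bit register.
inline __device__ uint32_t h0_h0(uint16_t a) {
    uint32_t b;
    asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
    return b;
}

// Element-wise FP32 fma used by the mixed-precision overloads below.
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
    return make_float2(fmaf(a.x, b.x, c.x), fmaf(a.y, b.y, c.y));
}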

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); }

inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }

inline __device__ void from_float(uint2& dst, Float4_ src) {
    dst.x = float2_to_half2(src.x);
    dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
    dst.x = float2_to_half2(src.x);
    dst.y = float2_to_half2(src.y);
    dst.z = float2_to_half2(src.z);
    dst.w = float2_to_half2(src.w);
}

// Vector fused multiply-add on packed half2 values.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); }

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
    uint2 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
    uint32_t s = h0_h0(a);
    uint2 d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
    uint4 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    d.z = fma(a.z, b.z, c.z);
    d.w = fma(a.w, b.w, c.w);
    return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
    uint32_t s = h0_h0(a);
    uint4 d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    d.z = fma(s, b.z, c.z);
    d.w = fma(s, b.w, c.w);
    return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
    float fa = half_to_float(a);
    float fb = half_to_float(b);
    return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
    float2 fa = half2_to_float2(a);
    float2 fb = half2_to_float2(b);
    return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); }

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
    Float4_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
    uint32_t s = h0_h0(a);
    Float4_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
    Float8_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    fd.z = fma(a.z, b.z, fc.z);
    fd.w = fma(a.w, b.w, fc.w);
    return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
    uint32_t s = h0_h0(a);
    Float8_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    fd.z = fma(s, b.z, fc.z);
    fd.w = fma(s, b.w, fc.w);
    return fd;
}
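
// NOTE: qk_dot_ below also relies on mul and sum overloads that this header
// does not define. A minimal sketch of the assumed forms (again following the
// FasterTransformer-style convention, not the commit's actual code):
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <>
inline __device__ float mul<float, uint16_t, uint16_t>(uint16_t a, uint16_t b) {
    return half_to_float(a) * half_to_float(b);
}

template <>
inline __device__ float2 mul<float2, uint32_t, uint32_t>(uint32_t a, uint32_t b) {
    float2 fa = half2_to_float2(a);
    float2 fb = half2_to_float2(b);
    return make_float2(fa.x * fb.x, fa.y * fb.y);
}

template <>
inline __device__ Float4_ mul<Float4_, uint2, uint2>(uint2 a, uint2 b) {
    Float4_ c;
    c.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
    c.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
    return c;
}

template <>
inline __device__ Float8_ mul<Float8_, uint4, uint4>(uint4 a, uint4 b) {
    Float8_ c;
    c.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
    c.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
    c.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
    c.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
    return c;
}

// Horizontal sums that collapse an accumulator vector into one float.
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(Float4_ v) { return sum(v.x) + sum(v.y); }
inline __device__ float sum(Float8_ v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); }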

// Q*K^T operation.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}

// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}

template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
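
For context on how these pieces fit together: inside a paged-attention kernel, each cached key token is scored by a THREAD_GROUP_SIZE-wide group of threads, with every thread holding N packed-FP16 vector fragments of Q and of that key. A hypothetical usage fragment (names and shapes are illustrative only, not part of this commit):

template <int THREAD_GROUP_SIZE, int N>
__device__ float attention_score(const uint32_t (&q_vecs)[N],
                                 const uint32_t (&k_vecs)[N],
                                 float scale) {
    // Every thread of the group returns the same fully reduced dot product.
    float qk = Qk_dot<uint16_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
    return qk * scale;  // typically scale = 1 / sqrt(head_size)
}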
