
Commit

feat(transformer-nvidia): add kv cache, implement reform
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 29, 2024
1 parent b39318b commit b91ab28
Showing 10 changed files with 223 additions and 24 deletions.
5 changes: 5 additions & 0 deletions tensor/src/tensor.rs
@@ -69,6 +69,11 @@ impl<Physical> Tensor<Physical> {
self.pattern.strides()
}

#[inline]
pub fn bytes_offset(&self) -> isize {
self.pattern.offset() as isize * self.data_type.size() as isize
}

#[inline]
pub const fn physical(&self) -> &Physical {
&self.physical
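
A note on the new helper (illustration, not part of the diff): bytes_offset turns the access pattern's element offset into a byte offset, so the CUDA-side code below (mat_mul.rs, reform.rs, rotary_embedding.rs) can add it directly to a raw device pointer. A minimal sketch, with made-up values:

    // Sketch: how bytes_offset is meant to combine with a base device pointer.
    fn displaced_ptr(base: usize, elem_offset: isize, elem_size: usize) -> usize {
        (base as isize + elem_offset * elem_size as isize) as usize
    }

    fn main() {
        // an f16 view starting 3 elements into its allocation: 3 * 2 bytes past the base
        assert_eq!(displaced_ptr(0x1000, 3, 2), 0x1006);
    }
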
2 changes: 1 addition & 1 deletion transformer-cpu/src/lib.rs
@@ -2,14 +2,14 @@ mod cache;
mod kernel;
mod storage;

use cache::LayerCache;
use common::{upos, utok};
use gemm::f16;
use kernel::{gather, matmul, rms_norm, rms_norm_inplace, rotary_embedding, softmax, swiglu};
use model_parameters::{Llama2, Memory};
use storage::Storage;
use tensor::{reslice, reslice_mut, slice, udim, DataType, Tensor};

pub use cache::LayerCache;
pub extern crate model_parameters;

pub struct Transformer {
32 changes: 32 additions & 0 deletions transformer-nvidia/src/cache.rs
@@ -0,0 +1,32 @@
use crate::{tensor, DevMem};
use cuda::Stream;
use model_parameters::Llama2;
use tensor::{udim, Tensor};

pub struct LayerCache<'a> {
/// Key cache, shape = `num_kv_head x max_seq_len x head_dim`.
k: Tensor<DevMem<'a>>,
/// Value cache, shape = `num_kv_head x max_seq_len x head_dim`.
v: Tensor<DevMem<'a>>,
}

impl<'a> LayerCache<'a> {
pub fn new_layers(model: &dyn Llama2, stream: &'a Stream) -> Vec<Self> {
let dt = model.data_type();
let nkvh = model.num_key_value_heads() as udim;
let hd = (model.hidden_size() / model.num_attention_heads()) as udim;
let max_seq_len = model.max_position_embeddings() as udim;
let shape = &[nkvh, max_seq_len, hd];
(0..model.num_hidden_layers())
.map(|_| Self {
k: tensor(dt, shape, stream),
v: tensor(dt, shape, stream),
})
.collect()
}

#[inline]
pub fn get(&self) -> (&Tensor<DevMem>, &Tensor<DevMem>) {
(&self.k, &self.v)
}
}
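
A usage sketch for the new cache type (not part of this commit; the crate path transformer_nvidia and the surrounding driver code are assumptions): one LayerCache is allocated per hidden layer, and get() hands back the K and V tensors that update() later slices and reforms into.

    use cuda::Stream;
    use model_parameters::Llama2;
    use transformer_nvidia::LayerCache; // re-exported from lib.rs below (assumed crate name)

    // Sketch only: allocate per-layer KV caches on a stream and check that each
    // layer's K and V share the [num_kv_head, max_seq_len, head_dim] shape.
    fn alloc_caches<'a>(model: &dyn Llama2, stream: &'a Stream) -> Vec<LayerCache<'a>> {
        let caches = LayerCache::new_layers(model, stream);
        for cache in &caches {
            let (k, v) = cache.get();
            debug_assert_eq!(k.shape(), v.shape());
        }
        caches
    }
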
4 changes: 1 addition & 3 deletions transformer-nvidia/src/kernel/mat_mul.rs
@@ -94,9 +94,7 @@ struct Matrix {
impl From<&Tensor<DevMem<'_>>> for Matrix {
fn from(tensor: &Tensor<DevMem>) -> Self {
let strides = tensor.strides();
let base = unsafe { tensor.physical().as_raw() };
let offset = tensor.pattern()[tensor.shape().len()] as cuda::bindings::CUdeviceptr;
let ptr = (base + offset) as _;
let ptr = (unsafe { tensor.physical().as_raw() } as isize + tensor.bytes_offset()) as _;
match tensor.shape() {
&[r, c] => Self {
batch: 1,
2 changes: 2 additions & 0 deletions transformer-nvidia/src/kernel/mod.rs
@@ -1,9 +1,11 @@
mod gather;
mod mat_mul;
mod reform;
mod rms_norm;
mod rotary_embedding;

pub(crate) use gather::gather;
pub(crate) use mat_mul::mat_mul;
pub(crate) use reform::Reform;
pub(crate) use rms_norm::RmsNormalization;
pub(crate) use rotary_embedding::RotaryEmbedding;
23 changes: 23 additions & 0 deletions transformer-nvidia/src/kernel/reform.cuh
@@ -0,0 +1,23 @@
template<class Tmem>
static __device__ void reform(
void *__restrict__ dst,
unsigned int const rsa,
unsigned int const csa,
void const *__restrict__ src,
unsigned int const rsb,
unsigned int const csb,
unsigned int const ncols) {

auto row = blockIdx.y,
col = blockIdx.x * blockDim.y + threadIdx.y;
if (col >= ncols) return;

auto thread = threadIdx.x,
warp_size = blockDim.x;
auto i = (row * rsa + col * csa) * warp_size + thread;
auto j = (row * rsb + col * csb) * warp_size + thread;
// printf("%d %d %d %d: row = %d, col = %d, nrows = %d, ncols = %d, rsa = %d, rsb = %d, csa = %d, csb = %d, warp_size = %d, thread = %d, i = %d, j = %d\n",
// blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x, row, col, gridDim.y, ncols, rsa, rsb, csa, csb, warp_size, thread, i, j);

reinterpret_cast<Tmem *>(dst)[i] = reinterpret_cast<Tmem const *>(src)[j];
}
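
The kernel copies one (row, col) cell per warp: each of the blockDim.x lanes moves one Tmem element, so the strides rsa/csa and rsb/csb are expressed in cells, where a cell is warp_size contiguous Tmem elements. A CPU reference of the same indexing, as a sketch (not part of the commit):

    // Reference of the copy the kernel performs. `rsa`/`csa` and `rsb`/`csb` are the
    // row/column strides of dst and src in cells; each cell is `warp_size` elements of T.
    fn reform_ref<T: Copy>(
        dst: &mut [T], rsa: usize, csa: usize,
        src: &[T], rsb: usize, csb: usize,
        nrows: usize, ncols: usize, warp_size: usize,
    ) {
        for row in 0..nrows {
            for col in 0..ncols {
                for thread in 0..warp_size {
                    let i = (row * rsa + col * csa) * warp_size + thread;
                    let j = (row * rsb + col * csb) * warp_size + thread;
                    dst[i] = src[j];
                }
            }
        }
    }
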
100 changes: 100 additions & 0 deletions transformer-nvidia/src/kernel/reform.rs
@@ -0,0 +1,100 @@
use crate::storage::DevMem;
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, Stream};
use std::ffi::c_void;
use tensor::{udim, Tensor};

pub struct Reform {
f: KernelFn,
block_size: udim,
warp_size: udim,
}

impl Reform {
pub fn new(block_size: usize, warp_size: usize, ctx: &ContextGuard) -> Self {
assert_eq!(
block_size % warp_size,
0,
"block_size must be a multiple of warp_size"
);

let name = "reform";

const REFORM: &str = include_str!("reform.cuh");
let code = format!(
r#"{REFORM}
extern "C" __global__ void {name}(
void *__restrict__ dst,
unsigned int const rsa,
unsigned int const csa,
void const *__restrict__ src,
unsigned int const rsb,
unsigned int const csb,
unsigned int const ncols,
unsigned int const bytes_per_thread
){{
switch (bytes_per_thread) {{
case 1: reform<uchar1 >(dst, rsa, csa, src, rsb, csb, ncols); break;
case 2: reform<uchar2 >(dst, rsa, csa, src, rsb, csb, ncols); break;
case 4: reform<float1 >(dst, rsa, csa, src, rsb, csb, ncols); break;
case 8: reform<float2 >(dst, rsa, csa, src, rsb, csb, ncols); break;
case 16: reform<float4 >(dst, rsa, csa, src, rsb, csb, ncols); break;
case 32: reform<double4>(dst, rsa, csa, src, rsb, csb, ncols); break;
}}
}}
"#
);

ctx.compile(code);
Self {
f: KernelFn::get(name).unwrap(),
block_size: block_size as _,
warp_size: warp_size as _,
}
}

pub fn launch(&self, dst: &Tensor<DevMem>, src: &Tensor<DevMem>, stream: &Stream) {
assert_eq!(dst.data_type(), src.data_type());
assert_eq!(dst.shape(), src.shape());

let &[r, c, b] = dst.shape() else {
unreachable!()
};
let &[rsa, csa, 1] = dst.strides() else {
unreachable!()
};
let &[rsb, csb, 1] = src.strides() else {
unreachable!()
};

let contiguous_bytes = b * dst.data_type().size() as udim;
assert_eq!(contiguous_bytes % self.warp_size, 0);
let bytes_per_thread = contiguous_bytes / self.warp_size;
assert!(bytes_per_thread <= 32 && bytes_per_thread.is_power_of_two());

let dst_ptr =
(unsafe { dst.physical().as_raw() } as isize + dst.bytes_offset()) as CUdeviceptr;
let rsa = rsa as udim / b;
let csa = csa as udim / b;
let src_ptr =
(unsafe { src.physical().as_raw() } as isize + src.bytes_offset()) as CUdeviceptr;
let rsb = rsb as udim / b;
let csb = csb as udim / b;
let params: [*const c_void; 8] = [
(&dst_ptr) as *const _ as _,
(&rsa) as *const _ as _,
(&csa) as *const _ as _,
(&src_ptr) as *const _ as _,
(&rsb) as *const _ as _,
(&csb) as *const _ as _,
(&c) as *const _ as _,
(&bytes_per_thread) as *const _ as _,
];

let max_warp_per_block = self.block_size / self.warp_size;
let grid_dims = ((c + max_warp_per_block - 1) / max_warp_per_block, r);
let block_dims = (self.warp_size, (c + grid_dims.0 - 1) / grid_dims.0);
self.f
.launch(grid_dims, block_dims, params.as_ptr(), 0, Some(stream));
}
}
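
A worked example of the launch geometry (illustrative numbers, not from the commit): with block_size = 1024 and warp_size = 32 as configured in lib.rs below, max_warp_per_block is 32, so a [8, 100, 128] f16 tensor gives bytes_per_thread = 128 * 2 / 32 = 8 (the float2 case) and the grid/block dimensions computed here:

    // Sketch: recompute the dimensions chosen by Reform::launch for one case.
    fn launch_dims(block_size: u32, warp_size: u32, r: u32, c: u32) -> ((u32, u32), (u32, u32)) {
        let max_warp_per_block = block_size / warp_size;
        let grid = ((c + max_warp_per_block - 1) / max_warp_per_block, r);
        let block = (warp_size, (c + grid.0 - 1) / grid.0);
        (grid, block)
    }

    fn main() {
        // e.g. a K slice of shape [num_kv_head = 8, seq_len = 100, head_dim = 128]
        assert_eq!(launch_dims(1024, 32, 8, 100), ((4, 8), (32, 25)));
    }
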
15 changes: 8 additions & 7 deletions transformer-nvidia/src/kernel/rotary_embedding.rs
@@ -1,5 +1,5 @@
use crate::storage::DevMem;
use cuda::{AsRaw, ContextGuard, KernelFn, Stream};
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, Stream};
use std::ffi::c_void;
use tensor::{udim, DataType, Tensor};

@@ -10,13 +10,13 @@ pub struct RotaryEmbedding {

impl RotaryEmbedding {
pub fn new(block_size: usize, ctx: &ContextGuard) -> Self {
let padding = format!("rotary_embedding_padding_{block_size}");
let name = "rotary_embedding_padding";

const ROTARY_EMBEDDING: &str = include_str!("rotary_embedding.cuh");
let code = format!(
r#"{ROTARY_EMBEDDING}
extern "C" __global__ void {padding}(
extern "C" __global__ void {name}(
half2 *__restrict__ x,
unsigned int const *__restrict__ pos,
float theta,
@@ -29,7 +29,7 @@ extern "C" __global__ void {padding}(

ctx.compile(code);
Self {
f: KernelFn::get(padding).unwrap(),
f: KernelFn::get(name).unwrap(),
block_size: block_size as _,
}
}
@@ -45,9 +45,10 @@ extern "C" __global__ void {padding}(
assert_eq!(pos.shape(), &[n]);
assert!(dh < self.block_size);

let t_ptr = unsafe { t.physical().as_raw() };
let pos_ptr = unsafe { pos.physical().as_raw() };
let leading_dim = t.strides()[0] as udim;
let t_ptr = (unsafe { t.physical().as_raw() } as isize + t.bytes_offset()) as CUdeviceptr;
let pos_ptr =
(unsafe { pos.physical().as_raw() } as isize + pos.bytes_offset()) as CUdeviceptr;
let leading_dim = t.strides()[0] as udim / 2;
let params: [*const c_void; 4] = [
(&t_ptr) as *const _ as _,
(&pos_ptr) as *const _ as _,
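
The new `/ 2` on leading_dim follows from the kernel signature above: x is declared as half2*, so a row stride counted in f16 elements has to be re-expressed in half2 (two-f16) units before it is passed to the kernel. A tiny sketch of that conversion, with an illustrative row width:

    // Sketch: stride of an f16 tensor, re-expressed in half2 units.
    fn leading_dim_in_half2(stride_in_f16: usize) -> usize {
        assert_eq!(stride_in_f16 % 2, 0, "rows must hold an even number of f16");
        stride_in_f16 / 2
    }

    fn main() {
        // e.g. an activation row of 4096 f16 -> 2048 half2 per row
        assert_eq!(leading_dim_in_half2(4096), 2048);
    }
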
61 changes: 49 additions & 12 deletions transformer-nvidia/src/lib.rs
@@ -1,21 +1,22 @@
#![cfg(detected_cuda)]

mod cache;
mod kernel;
mod parameters;
mod storage;

use common::{upos, utok};
use cublas::{bindings as cublas_def, cublas};
use cuda::{AsRaw, CudaDataType::half, Stream};
use kernel::{gather, mat_mul, RmsNormalization, RotaryEmbedding};
use kernel::{gather, mat_mul, Reform, RmsNormalization, RotaryEmbedding};
use model_parameters::Llama2;
use parameters::{LayersParameters, ModelParameters};
use std::ptr::null_mut;
use storage::DevMem;
use tensor::{slice, udim, DataType, Tensor};

pub use cache::LayerCache;
pub use storage::PageLockedMemory;

pub extern crate cuda;
pub extern crate model_parameters;

@@ -27,6 +28,7 @@ pub struct Transformer<'a> {
cublas: cublas_def::cublasHandle_t,
rms_norm: RmsNormalization,
rotary_embedding: RotaryEmbedding,
reform: Reform,
}

impl Drop for Transformer<'_> {
@@ -41,28 +43,36 @@ impl<'a> Transformer<'a> {
let d = host.hidden_size();
let mut cublas_handle = null_mut();
cublas!(cublasCreate_v2(&mut cublas_handle));
let ctx = stream.ctx();
let _dev = ctx.dev();
let block_size = 1024;
Self {
host,
model: ModelParameters::new(host, stream),
layers: LayersParameters::new(3, host, stream),

cublas: cublas_handle,
rms_norm: RmsNormalization::new(half, d, 1024, stream.ctx()),
rotary_embedding: RotaryEmbedding::new(1024, stream.ctx()),
rms_norm: RmsNormalization::new(half, d, block_size, stream.ctx()),
rotary_embedding: RotaryEmbedding::new(block_size, stream.ctx()),
reform: Reform::new(block_size, 32, stream.ctx()),
}
}

#[inline]
pub fn new_cache<'b>(&self, stream: &'b Stream) -> Vec<LayerCache<'b>> {
LayerCache::new_layers(&*self.host, &stream)
}

pub fn update(
&mut self,
tokens: &[utok],
// cache: &mut [LayerCache],
cache: &[LayerCache],
pos: upos,
compute: &Stream,
transfer: &Stream,
) {
let seq_len = tokens.len() as udim;
let d = self.host.hidden_size() as udim;
let nlayer = self.host.num_hidden_layers();
let nh = self.host.num_attention_heads() as udim;
let nkvh = self.host.num_key_value_heads() as udim;
let dh = d / nh;
@@ -98,7 +108,7 @@ impl<'a> Transformer<'a> {

cublas!(cublasSetStream_v2(self.cublas, compute.as_raw() as _));
compute.wait_for(&e_alloc);
for layer in 0..nlayer {
for (layer, cache) in cache.iter().enumerate() {
self.layers.load(layer, self.host, transfer);
let params = self.layers.sync(layer, compute);

@@ -124,15 +134,42 @@ impl<'a> Transformer<'a> {
// println!("layer {layer} v:\n{}", map_tensor(&v));
self.rotary_embedding.launch(&q, &pos, theta, compute);
self.rotary_embedding.launch(&k, &pos, theta, compute);
println!("layer {layer} rot q:\n{}", map_tensor(&q));
println!("layer {layer} rot k:\n{}", map_tensor(&k));
// compute.synchronize();
// println!("layer {layer} rot q:\n{}", map_tensor(&q));
// println!("layer {layer} rot k:\n{}", map_tensor(&k));
let q = q.transpose(&[1, 0, 2]);
let k = k.transpose(&[1, 0, 2]);
let v = v.transpose(&[1, 0, 2]);

// let (k_cache, v_cache) = cache.get();
// let mut k_cat = k_cache.slice(cat_slice);
// let mut v_cat = v_cache.slice(cat_slice);
let (k_cache, v_cache) = cache.get();
let k_cat = k_cache.slice(cat_slice);
let v_cat = v_cache.slice(cat_slice);
self.reform.launch(&q_att, &q, compute);
self.reform.launch(&k_cat, &k, compute);
self.reform.launch(&v_cat, &v, compute);

let q_att = q_att.clone().reshape(&[nkvh, head_group * seq_len, dh]);
let k_att = k_cache.slice(att_slice);
let v_att = v_cache.slice(att_slice);
// compute.synchronize();
// println!("layer {layer} q attention:\n{}", map_tensor(&q_att));
// println!("layer {layer} k attention:\n{}", map_tensor(&k_att));
// println!("layer {layer} v attention:\n{}", map_tensor(&v_att));

{
let k_att = k_att.transpose(&[0, 2, 1]);
mat_mul(self.cublas, &att, 0., &q_att, &k_att, head_div);
{
let att = att.clone().reshape(&[nh, seq_len, att_len]);
// softmax(&mut att.access_mut());
}
mat_mul(self.cublas, &x2, 0., &att, &v_att, 1.);
}
{
let x2 = x2.clone().reshape(&[nh, seq_len, dh]).transpose(&[1, 0, 2]);
let x1 = x1.clone().reshape(&[seq_len, nh, dh]);
self.reform.launch(&x1, &x2, compute);
}
}
}
}
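
Note that the softmax between the two GEMMs is still commented out in this commit (`// softmax(&mut att.access_mut());`), so attention scores are not yet normalized on the device. For reference, a numerically stable row softmax of the kind that call would apply, as a CPU-side sketch (not part of the commit):

    // Reference row softmax; the CUDA port is outside this commit.
    fn softmax_row(row: &mut [f32]) {
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let sum: f32 = row.iter_mut().map(|x| { *x = (*x - max).exp(); *x }).sum();
        row.iter_mut().for_each(|x| *x /= sum);
    }

    fn main() {
        let mut row = [1.0_f32, 2.0, 3.0];
        softmax_row(&mut row);
        assert!((row.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    }
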