-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(transformer-nvidia): 添加 kv cache,实现 reform
Signed-off-by: YdrMaster <[email protected]>
- Loading branch information
Showing
10 changed files
with
223 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use crate::{tensor, DevMem}; | ||
use cuda::Stream; | ||
use model_parameters::Llama2; | ||
use tensor::{udim, Tensor}; | ||
|
||
/// Per-layer KV cache held in device memory; `'a` ties the allocations
/// to the lifetime of the stream they were allocated on.
pub struct LayerCache<'a> {
    /// Key cache, shape = `num_kv_head x max_seq_len x head_dim`.
    k: Tensor<DevMem<'a>>,
    /// Value cache, shape = `num_kv_head x max_seq_len x head_dim`.
    v: Tensor<DevMem<'a>>,
}
|
||
impl<'a> LayerCache<'a> { | ||
pub fn new_layers(model: &dyn Llama2, stream: &'a Stream) -> Vec<Self> { | ||
let dt = model.data_type(); | ||
let nkvh = model.num_key_value_heads() as udim; | ||
let hd = (model.hidden_size() / model.num_attention_heads()) as udim; | ||
let max_seq_len = model.max_position_embeddings() as udim; | ||
let shape = &[nkvh, max_seq_len, hd]; | ||
(0..model.num_hidden_layers()) | ||
.map(|_| Self { | ||
k: tensor(dt, shape, stream), | ||
v: tensor(dt, shape, stream), | ||
}) | ||
.collect() | ||
} | ||
|
||
#[inline] | ||
pub fn get(&self) -> (&'a Tensor<DevMem>, &'a Tensor<DevMem>) { | ||
(&self.k, &self.v) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
// CUDA kernel wrappers used by the nvidia transformer backend,
// one submodule per kernel.
mod gather;
mod mat_mul;
mod reform;
mod rms_norm;
mod rotary_embedding;

// Crate-internal re-exports so callers use `kernel::gather` etc.
// without naming the submodules.
pub(crate) use gather::gather;
pub(crate) use mat_mul::mat_mul;
pub(crate) use reform::Reform;
pub(crate) use rms_norm::RmsNormalization;
pub(crate) use rotary_embedding::RotaryEmbedding;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Strided layout-to-layout copy ("reform"): moves one Tmem value per thread
// from a 2-D strided source to a 2-D strided destination.
//
// `rsa`/`csa` are the destination row/column strides and `rsb`/`csb` the
// source strides; judging by the index math below they are expressed in
// units of whole cells, where one cell is `warp_size` (= blockDim.x)
// contiguous Tmem values handled by the lanes of one warp.
//
// Grid mapping: blockIdx.y selects the row, blockIdx.x * blockDim.y +
// threadIdx.y selects the column, threadIdx.x is the lane within the cell.
template<class Tmem>
static __device__ void reform(
    void *__restrict__ dst,
    unsigned int const rsa,
    unsigned int const csa,
    void const *__restrict__ src,
    unsigned int const rsb,
    unsigned int const csb,
    unsigned int const ncols) {

    auto row = blockIdx.y,
         col = blockIdx.x * blockDim.y + threadIdx.y;
    // The column grid is rounded up on the host side, so trailing
    // threads can map past the real column count.
    if (col >= ncols) return;

    auto thread = threadIdx.x,
         warp_size = blockDim.x;
    auto i = (row * rsa + col * csa) * warp_size + thread;
    auto j = (row * rsb + col * csb) * warp_size + thread;

    // (Removed: commented-out debug printf dumping all indices.)
    reinterpret_cast<Tmem *>(dst)[i] = reinterpret_cast<Tmem const *>(src)[j];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
use crate::storage::DevMem; | ||
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, Stream}; | ||
use std::ffi::c_void; | ||
use tensor::{udim, Tensor}; | ||
|
||
/// JIT-compiled CUDA kernel that copies a tensor between two strided
/// layouts (see `reform.cuh`).
pub struct Reform {
    /// Compiled `reform` kernel entry point.
    f: KernelFn,
    /// Upper bound on threads per block; `block_size / warp_size` caps the
    /// number of columns (warps) packed into one block at launch time.
    block_size: udim,
    /// Threads that cooperate on one contiguous slice of the innermost
    /// dimension; must divide `block_size`.
    warp_size: udim,
}
|
||
impl Reform {
    /// Compiles the `reform` kernel (template from `reform.cuh` plus a
    /// generated `extern "C"` dispatcher) in the given CUDA context.
    ///
    /// `block_size` caps threads per block; `warp_size` is the number of
    /// threads that share one contiguous innermost-dimension slice.
    ///
    /// # Panics
    /// Panics if `block_size` is not a multiple of `warp_size`, or if the
    /// compiled kernel cannot be found afterwards.
    pub fn new(block_size: usize, warp_size: usize, ctx: &ContextGuard) -> Self {
        assert_eq!(
            block_size % warp_size,
            0,
            "block_size must be a multiple of warp_size"
        );

        let name = "reform";

        // The generated wrapper dispatches on the per-thread copy width,
        // picking a CUDA vector type of the matching size (1..=32 bytes).
        const REFORM: &str = include_str!("reform.cuh");
        let code = format!(
            r#"{REFORM}

extern "C" __global__ void {name}(
    void       *__restrict__ dst,
    unsigned int const rsa,
    unsigned int const csa,
    void const *__restrict__ src,
    unsigned int const rsb,
    unsigned int const csb,
    unsigned int const ncols,
    unsigned int const bytes_per_thread
){{
    switch (bytes_per_thread) {{
        case  1: reform<uchar1 >(dst, rsa, csa, src, rsb, csb, ncols); break;
        case  2: reform<uchar2 >(dst, rsa, csa, src, rsb, csb, ncols); break;
        case  4: reform<float1 >(dst, rsa, csa, src, rsb, csb, ncols); break;
        case  8: reform<float2 >(dst, rsa, csa, src, rsb, csb, ncols); break;
        case 16: reform<float4 >(dst, rsa, csa, src, rsb, csb, ncols); break;
        case 32: reform<double4>(dst, rsa, csa, src, rsb, csb, ncols); break;
    }}
}}
"#
        );

        ctx.compile(code);
        Self {
            f: KernelFn::get(name).unwrap(),
            block_size: block_size as _,
            warp_size: warp_size as _,
        }
    }

    /// Launches the reform copy from `src` into `dst` on `stream`.
    ///
    /// Both tensors must have the same data type and shape, be rank-3, and
    /// have unit stride in the last dimension (enforced by the `let else`
    /// destructuring below).
    ///
    /// # Panics
    /// Panics if dtype/shape differ, if the innermost row's byte count does
    /// not split evenly across `warp_size` threads, or if the resulting
    /// per-thread width is not a power of two ≤ 32 bytes.
    pub fn launch(&self, dst: &Tensor<DevMem>, src: &Tensor<DevMem>, stream: &Stream) {
        assert_eq!(dst.data_type(), src.data_type());
        assert_eq!(dst.shape(), src.shape());

        // r = rows, c = cols, b = contiguous innermost extent.
        let &[r, c, b] = dst.shape() else {
            unreachable!()
        };
        let &[rsa, csa, 1] = dst.strides() else {
            unreachable!()
        };
        let &[rsb, csb, 1] = src.strides() else {
            unreachable!()
        };

        // Bytes in one contiguous innermost row, split across the warp:
        // each thread moves bytes_per_thread, matching one dispatcher case.
        let contiguous_bytes = b * dst.data_type().size() as udim;
        assert_eq!(contiguous_bytes % self.warp_size, 0);
        let bytes_per_thread = contiguous_bytes / self.warp_size;
        assert!(bytes_per_thread <= 32 && bytes_per_thread.is_power_of_two());

        // NOTE(review): strides are cast isize -> udim before dividing;
        // a negative stride would wrap around — presumably callers only
        // pass non-negative-stride views. TODO confirm.
        // NOTE(review): the division assumes each stride is a multiple of
        // `b` (converting element strides to whole-row units); this is not
        // asserted and would silently truncate otherwise.
        let dst_ptr =
            (unsafe { dst.physical().as_raw() } as isize + dst.bytes_offset()) as CUdeviceptr;
        let rsa = rsa as udim / b;
        let csa = csa as udim / b;
        let src_ptr =
            (unsafe { src.physical().as_raw() } as isize + src.bytes_offset()) as CUdeviceptr;
        let rsb = rsb as udim / b;
        let csb = csb as udim / b;
        // Argument order must match the generated kernel signature exactly.
        let params: [*const c_void; 8] = [
            (&dst_ptr) as *const _ as _,
            (&rsa) as *const _ as _,
            (&csa) as *const _ as _,
            (&src_ptr) as *const _ as _,
            (&rsb) as *const _ as _,
            (&csb) as *const _ as _,
            (&c) as *const _ as _,
            (&bytes_per_thread) as *const _ as _,
        ];

        // Grid: x covers columns in groups of up to block_size/warp_size
        // warps, y covers rows. Block: x = lanes (warp_size), y = warps,
        // i.e. ceil(c / grid.x) columns handled per block.
        let max_warp_per_block = self.block_size / self.warp_size;
        let grid_dims = ((c + max_warp_per_block - 1) / max_warp_per_block, r);
        let block_dims = (self.warp_size, (c + grid_dims.0 - 1) / grid_dims.0);
        self.f
            .launch(grid_dims, block_dims, params.as_ptr(), 0, Some(stream));
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.