refactor(transformer-nvidia): update the cuda crate, move the device-memory abstraction upstream
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 1, 2024
1 parent cff122d commit a436a88
Showing 11 changed files with 57 additions and 116 deletions.
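
The commit swaps the crate-local DevMem device-memory wrapper (deleted from transformer-nvidia/src/storage.rs below) for LocalDevBlob from the upstream cuda crate. A minimal before/after sketch of the allocation and upload calls as they appear in this diff; the exact upstream signatures are inferred from their call sites here, and the variable names are illustrative:

use cuda::{LocalDevBlob, Stream};

// Old path (the crate-local wrapper removed by this commit):
//     let buf = DevMem::new(n_bytes, stream);             // stream-ordered device alloc
//     let pos = DevMem::from_slice(&positions, stream);    // alloc + host-to-device copy
//     buf.copy_in(&host_bytes, stream);                    // refill an existing buffer
//
// New path through the upstream cuda crate, sketched under the assumptions above:
fn upload<'a>(
    stream: &'a Stream,
    host_bytes: &[u8],
    positions: &[u32],
) -> (LocalDevBlob<'a>, LocalDevBlob<'a>) {
    let mut buf = stream.malloc::<u8>(host_bytes.len()); // stream-ordered device alloc
    buf.copy_in_async(host_bytes, stream);               // explicit async host-to-device copy
    let pos = stream.from_host(positions);               // alloc + copy in one call
    (buf, pos)
}
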
10 changes: 5 additions & 5 deletions transformer-nvidia/src/cache.rs
@@ -1,13 +1,13 @@
use crate::{tensor, DevMem};
use cuda::Stream;
use crate::tensor;
use cuda::{LocalDevBlob, Stream};
use model_parameters::Llama2;
use tensor::{udim, Tensor};

pub struct LayerCache<'a> {
/// Key cache, shape = `num_kv_head x max_seq_len x head_dim`.
k: Tensor<DevMem<'a>>,
k: Tensor<LocalDevBlob<'a>>,
/// Value cache, shape = `num_kv_head x max_seq_len x head_dim`.
v: Tensor<DevMem<'a>>,
v: Tensor<LocalDevBlob<'a>>,
}

impl<'a> LayerCache<'a> {
@@ -26,7 +26,7 @@ impl<'a> LayerCache<'a> {
}

#[inline]
pub fn get(&self) -> (&'a Tensor<DevMem>, &'a Tensor<DevMem>) {
pub fn get(&self) -> (&'a Tensor<LocalDevBlob>, &'a Tensor<LocalDevBlob>) {
(&self.k, &self.v)
}
}
7 changes: 4 additions & 3 deletions transformer-nvidia/src/kernel/fused_softmax.rs
@@ -1,5 +1,6 @@
use crate::storage::DevMem;
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, CudaDataType, KernelFn, Stream};
use cuda::{
bindings::CUdeviceptr, AsRaw, ContextGuard, CudaDataType, KernelFn, LocalDevBlob, Stream,
};
use std::ffi::{c_uint, c_void};
use tensor::Tensor;

@@ -59,7 +60,7 @@ extern "C" __global__ void {folding}(
}
}

pub fn launch(&self, att: &Tensor<DevMem>, stream: &Stream) {
pub fn launch(&self, att: &Tensor<LocalDevBlob>, stream: &Stream) {
assert!(att.is_contiguous());
let &[nh, seq_len, att_len] = att.shape() else {
panic!("Invalid attention shape");
7 changes: 3 additions & 4 deletions transformer-nvidia/src/kernel/gather.rs
@@ -1,10 +1,9 @@
use crate::storage::DevMem;
use common::utok;
use cuda::{bindings::CUdeviceptr, AsRaw, Stream};
use common::utok;
use cuda::{bindings::CUdeviceptr, AsRaw, LocalDevBlob, Stream};
use std::ops::Deref;
use tensor::Tensor;

pub fn gather<T>(x: &Tensor<DevMem>, table: &Tensor<T>, tokens: &[utok], stream: &Stream)
pub fn gather<T>(x: &Tensor<LocalDevBlob>, table: &Tensor<T>, tokens: &[utok], stream: &Stream)
where
T: Deref<Target = [u8]>,
{
15 changes: 7 additions & 8 deletions transformer-nvidia/src/kernel/mat_mul.rs
@@ -1,6 +1,5 @@
use crate::storage::DevMem;
use cublas::{bindings::cublasOperation_t, cublas};
use cuda::AsRaw;
use cublas::{bindings::cublasOperation_t, cublas};
use cuda::{AsRaw, LocalDevBlob};
use half::f16;
use std::{
ffi::{c_int, c_longlong, c_void},
@@ -10,10 +9,10 @@ use tensor::{DataType, Tensor};

pub fn mat_mul(
handle: cublas::bindings::cublasHandle_t,
c: &Tensor<DevMem>,
c: &Tensor<LocalDevBlob>,
beta: f32,
a: &Tensor<DevMem>,
b: &Tensor<DevMem>,
a: &Tensor<LocalDevBlob>,
b: &Tensor<LocalDevBlob>,
alpha: f32,
) {
assert_eq!(c.data_type(), DataType::F16);
@@ -91,8 +90,8 @@ struct Matrix {
ptr: *mut c_void,
}

impl From<&Tensor<DevMem<'_>>> for Matrix {
fn from(tensor: &Tensor<DevMem>) -> Self {
impl From<&Tensor<LocalDevBlob<'_>>> for Matrix {
fn from(tensor: &Tensor<LocalDevBlob>) -> Self {
let strides = tensor.strides();
let ptr = (unsafe { tensor.physical().as_raw() } as isize + tensor.bytes_offset()) as _;
match tensor.shape() {
5 changes: 2 additions & 3 deletions transformer-nvidia/src/kernel/reform.rs
@@ -1,5 +1,4 @@
use crate::storage::DevMem;
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, Stream};
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, LocalDevBlob, Stream};
use std::ffi::{c_uint, c_void};
use tensor::{udim, Tensor};

@@ -53,7 +52,7 @@ extern "C" __global__ void {name}(
}
}

pub fn launch(&self, dst: &Tensor<DevMem>, src: &Tensor<DevMem>, stream: &Stream) {
pub fn launch(&self, dst: &Tensor<LocalDevBlob>, src: &Tensor<LocalDevBlob>, stream: &Stream) {
assert_eq!(dst.data_type(), src.data_type());
assert_eq!(dst.shape(), src.shape());

9 changes: 4 additions & 5 deletions transformer-nvidia/src/kernel/rms_norm.rs
@@ -1,5 +1,4 @@
use crate::DevMem;
use cuda::{AsRaw, ContextGuard, CudaDataType, KernelFn, Stream};
use cuda::{AsRaw, ContextGuard, CudaDataType, KernelFn, LocalDevBlob, Stream};
use std::ffi::{c_uint, c_void};

pub struct RmsNormalization {
@@ -63,9 +62,9 @@ extern "C" __global__ void {folding}(

pub fn launch(
&self,
y: &DevMem,
x: &DevMem,
w: &DevMem,
y: &LocalDevBlob,
x: &LocalDevBlob,
w: &LocalDevBlob,
epsilon: f32,
leading_dim: usize,
stream: &Stream,
11 changes: 8 additions & 3 deletions transformer-nvidia/src/kernel/rotary_embedding.rs
@@ -1,5 +1,4 @@
use crate::storage::DevMem;
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, Stream};
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, KernelFn, LocalDevBlob, Stream};
use std::ffi::{c_uint, c_void};
use tensor::{udim, DataType, Tensor};

@@ -34,7 +33,13 @@ extern "C" __global__ void {name}(
}
}

pub fn launch(&self, t: &Tensor<DevMem>, pos: &Tensor<DevMem>, theta: f32, stream: &Stream) {
pub fn launch(
&self,
t: &Tensor<LocalDevBlob>,
pos: &Tensor<LocalDevBlob>,
theta: f32,
stream: &Stream,
) {
let &[n, nh, dh] = t.shape() else {
panic!("Invalid shape");
};
7 changes: 4 additions & 3 deletions transformer-nvidia/src/kernel/swiglu.rs
@@ -1,5 +1,6 @@
use crate::storage::DevMem;
use cuda::{bindings::CUdeviceptr, AsRaw, ContextGuard, CudaDataType, KernelFn, Stream};
use cuda::{
bindings::CUdeviceptr, AsRaw, ContextGuard, CudaDataType, KernelFn, LocalDevBlob, Stream,
};
use std::ffi::{c_uint, c_void};
use tensor::{udim, Tensor};

@@ -35,7 +36,7 @@ extern "C" __global__ void {name}(
}
}

pub fn launch(&self, gate: &Tensor<DevMem>, up: &Tensor<DevMem>, stream: &Stream) {
pub fn launch(&self, gate: &Tensor<LocalDevBlob>, up: &Tensor<LocalDevBlob>, stream: &Stream) {
assert_eq!(gate.data_type(), up.data_type());
assert_eq!(gate.shape(), up.shape());

15 changes: 7 additions & 8 deletions transformer-nvidia/src/lib.rs
@@ -8,12 +8,11 @@ mod storage;
use ::half::f16;
use common::{upos, utok};
use cublas::{bindings as cublas_def, cublas};
use cuda::{AsRaw, CudaDataType::half, Stream};
use cuda::{AsRaw, CudaDataType::half, LocalDevBlob, Stream};
use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbedding, Swiglu};
use model_parameters::Llama2;
use parameters::{LayersParameters, ModelParameters};
use std::ptr::null_mut;
use storage::DevMem;
use tensor::{slice, udim, DataType, Tensor};

pub use cache::LayerCache;
@@ -33,7 +32,7 @@ pub struct Transformer<'a> {
fused_softmax: FusedSoftmax,
swiglu: Swiglu,

logits_dev: Tensor<DevMem<'a>>,
logits_dev: Tensor<LocalDevBlob<'a>>,
logits: Vec<f32>,
}

@@ -84,7 +83,7 @@ impl<'a> Transformer<'a> {
pos: upos,
compute: &Stream,
transfer: &'b Stream,
) -> Tensor<DevMem<'b>> {
) -> Tensor<LocalDevBlob<'b>> {
let seq_len = tokens.len() as udim;
let d = self.host.hidden_size() as udim;
let nh = self.host.num_attention_heads() as udim;
@@ -100,7 +99,7 @@ impl<'a> Transformer<'a> {
let att_len = pos + seq_len;
let cat_slice = &[slice![all], slice![pos; 1; seq_len], slice![all]];
let att_slice = &[slice![all], slice![ 0; 1; att_len], slice![all]];
let pos = DevMem::from_slice(&(pos..pos + seq_len).collect::<Vec<udim>>(), transfer);
let pos = transfer.from_host(&(pos..pos + seq_len).collect::<Vec<udim>>());
let pos = Tensor::new(DataType::U32, &[seq_len], pos);
// println!("tokens: {tokens:?}");

@@ -284,16 +283,16 @@ impl<'a> Transformer<'a> {
}

#[inline]
fn tensor<'a>(dt: DataType, shape: &[udim], stream: &'a Stream) -> Tensor<DevMem<'a>> {
fn tensor<'a>(dt: DataType, shape: &[udim], stream: &'a Stream) -> Tensor<LocalDevBlob<'a>> {
Tensor::new(
dt,
shape,
DevMem::new(shape.iter().product::<udim>() as usize * dt.size(), stream),
stream.malloc::<u8>(shape.iter().product::<udim>() as usize * dt.size()),
)
}

#[allow(unused)]
fn map_tensor(tensor: &Tensor<DevMem>) -> Tensor<Vec<u8>> {
fn map_tensor(tensor: &Tensor<LocalDevBlob>) -> Tensor<Vec<u8>> {
unsafe {
tensor.map_physical(|dev| {
let len = dev.len();
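
For reference, the tensor helper above is how every intermediate buffer in Transformer::update is now built: the element count times the element size is requested directly from the stream. A small usage sketch, reusing names already defined in update (seq_len, d, compute):

// An f16 hidden-state buffer of shape [seq_len, d] on the compute stream,
// as update() allocates its intermediates.
let x = tensor(DataType::F16, &[seq_len, d], compute);
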
28 changes: 12 additions & 16 deletions transformer-nvidia/src/parameters.rs
@@ -1,22 +1,18 @@
use crate::DevMem;
use cuda::Stream;
use cuda::{LocalDevBlob, Stream};
use model_parameters::Llama2;
use tensor::Tensor;

pub(crate) struct ModelParameters<'a> {
pub(crate) model_norm: Tensor<DevMem<'a>>,
pub(crate) lm_head: Tensor<DevMem<'a>>,
pub(crate) model_norm: Tensor<LocalDevBlob<'a>>,
pub(crate) lm_head: Tensor<LocalDevBlob<'a>>,
pub(crate) sync_event: cuda::Event,
}

impl<'a> ModelParameters<'a> {
pub fn new(host: &dyn Llama2, stream: &'a Stream) -> Self {
macro_rules! map {
($param:ident) => {
unsafe {
host.$param()
.map_physical(|slice| DevMem::from_slice(slice, stream))
}
unsafe { host.$param().map_physical(|slice| stream.from_host(slice)) }
};
}
Self {
@@ -71,12 +67,12 @@ impl<'a> LayersParameters<'a> {
}

pub(crate) struct LayerParameter<'a> {
pub input_layernorm: Tensor<DevMem<'a>>,
pub w_qkv: Tensor<DevMem<'a>>,
pub self_attn_o_proj: Tensor<DevMem<'a>>,
pub post_attention_layernorm: Tensor<DevMem<'a>>,
pub mlp_gate_up: Tensor<DevMem<'a>>,
pub mlp_down: Tensor<DevMem<'a>>,
pub input_layernorm: Tensor<LocalDevBlob<'a>>,
pub w_qkv: Tensor<LocalDevBlob<'a>>,
pub self_attn_o_proj: Tensor<LocalDevBlob<'a>>,
pub post_attention_layernorm: Tensor<LocalDevBlob<'a>>,
pub mlp_gate_up: Tensor<LocalDevBlob<'a>>,
pub mlp_down: Tensor<LocalDevBlob<'a>>,

layer: usize,
sync_event: cuda::Event,
@@ -88,7 +84,7 @@ impl<'a> LayerParameter<'a> {
($param:ident) => {
unsafe {
host.$param(layer)
.map_physical(|slice| DevMem::from_slice(slice, stream))
.map_physical(|slice| stream.from_host(slice))
}
};
}
@@ -113,7 +109,7 @@ impl<'a> LayerParameter<'a> {
($param:ident) => {
self.$param
.physical_mut()
.copy_in(host.$param(layer).as_slice(), stream)
.copy_in_async(host.$param(layer).as_slice(), stream)
};
}
update!(input_layernorm);
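
parameters.rs keeps its two upload paths, both now routed through the upstream crate: fresh tensors are created with stream.from_host inside the map! macro, and the rotating LayerParameter slots are refreshed in place with the explicitly asynchronous copy_in_async. Expanded for a single field, the update! macro body reads roughly as follows (a sketch of the expansion, not code added by the commit):

// Reload this LayerParameter slot with the weights of another layer,
// reusing the existing device allocation instead of reallocating it.
self.input_layernorm
    .physical_mut()
    .copy_in_async(host.input_layernorm(layer).as_slice(), stream);
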
59 changes: 1 addition & 58 deletions transformer-nvidia/src/storage.rs
@@ -1,6 +1,5 @@
use cuda::{AsRaw, Context, ContextGuard, Stream};
use cuda::{Context, ContextGuard};
use std::{
mem::size_of_val,
ops::{Deref, DerefMut},
ptr::null_mut,
sync::Arc,
@@ -47,59 +46,3 @@ impl DerefMut for PageLockedMemory {
unsafe { std::slice::from_raw_parts_mut(self.ptr as _, self.len) }
}
}

#[derive(Clone)]
pub struct DevMem<'a> {
ptr: cuda::bindings::CUdeviceptr,
len: usize,
_stream: &'a Stream<'a>,
}

impl<'a> DevMem<'a> {
pub fn new(len: usize, stream: &'a Stream) -> Self {
let mut ptr = 0;
cuda::driver!(cuMemAllocAsync(&mut ptr, len, stream.as_raw()));
Self {
ptr,
len,
_stream: stream,
}
}

pub fn from_slice<T: Copy>(slice: &[T], stream: &'a Stream) -> Self {
let stream_ = unsafe { stream.as_raw() };
let len = size_of_val(slice);
let src = slice.as_ptr().cast();
let mut ptr = 0;
cuda::driver!(cuMemAllocAsync(&mut ptr, len, stream_));
cuda::driver!(cuMemcpyHtoDAsync_v2(ptr, src, len, stream_));
Self {
ptr,
len,
_stream: stream,
}
}
}

impl DevMem<'_> {
#[inline]
pub fn len(&self) -> usize {
self.len
}

pub fn copy_in<T: Copy>(&mut self, slice: &[T], stream: &Stream) {
let len = size_of_val(slice);
let src = slice.as_ptr().cast();
assert_eq!(len, self.len);
cuda::driver!(cuMemcpyHtoDAsync_v2(self.ptr, src, len, stream.as_raw()));
}
}

impl AsRaw for DevMem<'_> {
type Raw = cuda::bindings::CUdeviceptr;

#[inline]
unsafe fn as_raw(&self) -> Self::Raw {
self.ptr
}
}
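
Everything the deleted DevMem wrapper provided, namely stream-ordered allocation (cuMemAllocAsync), host-to-device upload (cuMemcpyHtoDAsync_v2), and raw-pointer access through AsRaw, is what the rest of this diff now takes from cuda::LocalDevBlob. The kernels keep reading the raw CUdeviceptr the same way, since LocalDevBlob also implements AsRaw; a minimal sketch of that access, matching the offset arithmetic in mat_mul.rs above (`att` stands for any Tensor<LocalDevBlob>, such as the attention tensor in FusedSoftmax::launch):

use cuda::AsRaw;

// Turn a device tensor into a raw pointer for a driver-API launch,
// as Matrix::from and the kernel launchers do.
let ptr: cuda::bindings::CUdeviceptr =
    (unsafe { att.physical().as_raw() } as isize + att.bytes_offset()) as _;
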
