
Commit

refactor(transformer-nvidia): tidy up and optimize transformer-nvidia
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Mar 15, 2024
1 parent 367de90 commit f66bcd2
Showing 14 changed files with 126 additions and 269 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 0 additions & 2 deletions tensor/src/lib.rs
@@ -3,7 +3,6 @@ mod compatibility;
 mod data_type;
 mod fmt;
 mod pattern;
-mod physical_cell;
 mod reshape;
 mod slice;
 mod split;
@@ -20,7 +19,6 @@ pub use compatibility::Compatibility;
 pub use data_type::DataType;
 pub use nalgebra::DVector;
 pub use pattern::{expand_indices, idx_strides, Affine, Shape};
-pub use physical_cell::PhysicalCell;
 pub use slice::SliceDim;
 pub use split::Splitable;
 pub use tensor::Tensor;
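
With `PhysicalCell` gone from the tensor crate, storage types no longer go through an accessor trait; judging from the remaining exports and the storage.rs diff below, a storage type now only needs `Deref`/`DerefMut` to its raw memory plus `Splitable`. A minimal sketch of that surface, assuming `Splitable` keeps the `fn split(&self) -> Self` shape shown later in this commit; `HostStorage` is a made-up type used only for illustration:

use std::ops::{Deref, DerefMut};
use tensor::Splitable;

// Hypothetical host-side storage, here only to illustrate the trait surface.
struct HostStorage(Vec<u8>);

impl Deref for HostStorage {
    type Target = [u8];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for HostStorage {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl Splitable for HostStorage {
    // A split hands out another handle to the data; the CUDA Storage below
    // clones an Rc, this toy version just copies the buffer.
    fn split(&self) -> Self {
        Self(self.0.clone())
    }
}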
41 changes: 0 additions & 41 deletions tensor/src/physical_cell.rs

This file was deleted.

2 changes: 1 addition & 1 deletion transformer-cpu/src/lib.rs
@@ -219,7 +219,7 @@ impl Transformer {
             x.as_ref()
                 .map_physical(|u| std::slice::from_raw_parts(u.as_ptr(), u.len()))
         };
-        rms_norm(&mut x, &x_, &self.0.model_norm(), self.0.rms_norm_eps());
+        rms_norm(&mut x, &x_, &self.0.model_norm(), epsilon);
         // println!("model norm:\n{}", x.access());

         let lm_head = self.0.lm_head().transpose(&[1, 0]);
6 changes: 3 additions & 3 deletions transformer-nvidia/src/kernel/fused_softmax.rs
@@ -3,7 +3,7 @@
 };
 use std::{
     ffi::{c_uint, c_void, CString},
-    ops::Deref,
+    ops::DerefMut,
 };
 use tensor::Tensor;

@@ -70,9 +70,9 @@ extern "C" __global__ void {folding}(
 }

 impl FusedSoftmax<'_> {
-    pub fn launch<T>(&self, att: &Tensor<T>, stream: &Stream)
+    pub fn launch<T>(&self, att: &mut Tensor<T>, stream: &Stream)
     where
-        T: Deref<Target = DevSlice>,
+        T: DerefMut<Target = DevSlice>,
     {
         assert!(att.is_contiguous());
         let &[nh, seq_len, att_len] = att.shape() else {
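
The in-place kernels now take the tensor they write by `&mut`, and the storage bound tightens from `Deref` to `DerefMut`, so mutation shows up in the signature instead of happening through interior mutability. A hedged call-site sketch; `softmax`, `att`, and `stream` are illustrative bindings, not code from this commit:

// Assumed bindings: softmax: FusedSoftmax, att: Tensor<Storage>, stream: cuda::Stream.
softmax.launch(&mut att, &stream);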
6 changes: 3 additions & 3 deletions transformer-nvidia/src/kernel/gather.rs
@@ -3,11 +3,11 @@ use cuda::{bindings::CUdeviceptr, AsRaw, DevSlice, Stream};
 use std::ops::{Deref, DerefMut};
 use tensor::Tensor;

-pub fn gather<'a, T, U, I>(x: Tensor<T>, table: &Tensor<U>, requests: I, stream: &Stream)
+pub fn gather<T, U, I>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: I, stream: &Stream)
 where
     T: DerefMut<Target = DevSlice>,
     U: Deref<Target = [u8]>,
-    I: IntoIterator<Item = &'a [utok]>,
+    I: IntoIterator<Item = utok>,
 {
     let &[_, d] = x.shape() else { panic!() };

@@ -21,7 +21,7 @@ where
     let x = unsafe { x.physical().as_raw() };
     let table = table.as_slice();
     let stream = unsafe { stream.as_raw() };
-    for (i, &t) in requests.into_iter().flatten().enumerate() {
+    for (i, t) in tokens.into_iter().enumerate() {
         let src = table[d * t as usize..].as_ptr().cast();
         let dst = x + (d * i) as CUdeviceptr;
         cuda::driver!(cuMemcpyHtoDAsync_v2(dst, src, d, stream));
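
`gather` now writes through `&mut Tensor<T>` and consumes a flat iterator of token ids instead of per-request `&[utok]` slices, so flattening moves to the caller. A hedged call-site sketch; `x`, `table`, `tokens`, and `stream` are illustrative bindings:

// Assumed: tokens: Vec<utok>, already flattened across requests by the caller.
gather(&mut x, &table, tokens.iter().copied(), &stream);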
16 changes: 9 additions & 7 deletions transformer-nvidia/src/kernel/mat_mul.rs
@@ -4,25 +4,27 @@ use half::f16;
 use std::{
     ffi::{c_int, c_longlong, c_void},
     mem::swap,
-    ops::Deref,
+    ops::{Deref, DerefMut},
 };
 use tensor::{DataType, Tensor};

-pub fn mat_mul<T>(
+pub fn mat_mul<T, U, V>(
     handle: &Cublas,
-    c: &Tensor<T>,
+    c: &mut Tensor<T>,
     beta: f32,
-    a: &Tensor<T>,
-    b: &Tensor<T>,
+    a: &Tensor<U>,
+    b: &Tensor<V>,
     alpha: f32,
 ) where
-    T: Deref<Target = DevSlice>,
+    T: DerefMut<Target = DevSlice>,
+    U: Deref<Target = DevSlice>,
+    V: Deref<Target = DevSlice>,
 {
     assert_eq!(c.data_type(), DataType::F16);
     assert_eq!(a.data_type(), DataType::F16);
     assert_eq!(b.data_type(), DataType::F16);

-    let mut c = Matrix::from(c);
+    let mut c = Matrix::from(&*c);
     let mut a = Matrix::from(a);
     let mut b = Matrix::from(b);
     assert_eq!(c.r, a.r); // m
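
Splitting the single `T` parameter into `T`, `U`, `V` lets the output and the two inputs be backed by different storage types, as long as `c` derefs mutably to device memory and `a`, `b` deref immutably. A hedged sketch of a call this now permits; the bindings are illustrative, not code from this commit:

// Assumed bindings: cublas: Cublas, x1/x0: activation tensors, w_qkv: a weight
// tensor whose storage type differs from the activations'.
mat_mul(&cublas, &mut x1, 0.0, &x0, &w_qkv, 1.0);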
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/reform.rs
@@ -62,7 +62,7 @@ extern "C" __global__ void {name}(
 }

 impl Reform<'_> {
-    pub fn launch<T, U>(&self, dst: Tensor<T>, src: &Tensor<U>, stream: &Stream)
+    pub fn launch<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, stream: &Stream)
     where
         T: DerefMut<Target = DevSlice>,
         U: Deref<Target = DevSlice>,
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/rms_norm.rs
@@ -73,7 +73,7 @@ extern "C" __global__ void {folding}(
 impl RmsNormalization<'_> {
     pub fn launch<T, U, V>(
         &self,
-        y: Tensor<T>,
+        y: &mut Tensor<T>,
         x: &Tensor<U>,
         w: &Tensor<V>,
         epsilon: f32,
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/rotary_embedding.rs
@@ -43,7 +43,7 @@ extern "C" __global__ void {name}(
 }

 impl RotaryEmbedding<'_> {
-    pub fn launch<T, U>(&self, t: &Tensor<T>, pos: &Tensor<U>, theta: f32, stream: &Stream)
+    pub fn launch<T, U>(&self, t: &mut Tensor<T>, pos: &Tensor<U>, theta: f32, stream: &Stream)
     where
         T: Deref<Target = DevSlice>,
         U: Deref<Target = DevSlice>,
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/swiglu.rs
@@ -46,7 +46,7 @@ extern "C" __global__ void {name}(
 }

 impl Swiglu<'_> {
-    pub fn launch<T>(&self, gate: &Tensor<T>, up: &Tensor<T>, stream: &Stream)
+    pub fn launch<T>(&self, gate: &mut Tensor<T>, up: &Tensor<T>, stream: &Stream)
     where
         T: Deref<Target = DevSlice>,
     {
240 changes: 82 additions & 158 deletions transformer-nvidia/src/lib.rs

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions transformer-nvidia/src/parameters.rs
@@ -22,7 +22,7 @@ impl<'ctx> ModelParameters<'ctx> {
         }
         Self {
             model_norm: map!(model_norm),
-            lm_head: map!(lm_head),
+            lm_head: map!(lm_head).transpose(&[1, 0]),
             sync_event: stream.record(),
         }
     }
@@ -96,11 +96,11 @@ impl<'ctx> LayerParameter<'ctx> {
         }
         Self {
             input_layernorm: map!(input_layernorm),
-            w_qkv: map!(w_qkv),
-            self_attn_o_proj: map!(self_attn_o_proj),
+            w_qkv: map!(w_qkv).transpose(&[1, 0]),
+            self_attn_o_proj: map!(self_attn_o_proj).transpose(&[1, 0]),
             post_attention_layernorm: map!(post_attention_layernorm),
-            mlp_gate_up: map!(mlp_gate_up),
-            mlp_down: map!(mlp_down),
+            mlp_gate_up: map!(mlp_gate_up).transpose(&[1, 0]),
+            mlp_down: map!(mlp_down).transpose(&[1, 0]),
             layer,
             sync_event: stream.record(),
         }
@@ -114,7 +114,6 @@ impl<'ctx> LayerParameter<'ctx> {
         macro_rules! update {
             ($param:ident) => {
                 self.$param
-                    .access_mut()
                     .physical_mut()
                     .copy_in_async(host.$param(layer).as_slice(), stream)
             };
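
The `.transpose(&[1, 0])` calls move from the compute path into parameter upload, so the projection weights are held as pre-transposed views from the start, and the `update!` macro drops `.access_mut()` because the tensor's physical storage is now reached directly. A hedged sketch of the resulting call site; `cublas`, `logits`, `x`, and `params` are illustrative bindings, not code from this commit:

// With lm_head stored pre-transposed, a caller can use it directly instead of
// transposing on every step (illustrative only).
mat_mul(&cublas, &mut logits, 0.0, &x, &params.lm_head, 1.0);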
59 changes: 17 additions & 42 deletions transformer-nvidia/src/storage.rs
@@ -1,75 +1,50 @@
 use cuda::{DevMem, DevSlice, Stream};
 use std::{
-    cell::{Ref, RefCell, RefMut},
     ops::{Deref, DerefMut},
     rc::Rc,
 };
-use tensor::PhysicalCell;
+use tensor::Splitable;

 #[derive(Clone)]
-pub struct Storage<'ctx>(Rc<RefCell<DevMem<'ctx>>>);
-pub struct MemRef<'a, 'ctx: 'a>(Ref<'a, DevMem<'ctx>>);
-pub struct MemRefMut<'a, 'ctx: 'a>(RefMut<'a, DevMem<'ctx>>);
-
-impl<'ctx> PhysicalCell for Storage<'ctx> {
-    type Raw = DevSlice;
-    type Access<'a> = MemRef<'a, 'ctx> where Self: 'a;
-    type AccessMut<'a> = MemRefMut<'a, 'ctx> where Self: 'a;
-
-    #[inline]
-    unsafe fn get_unchecked(&self) -> &Self::Raw {
-        &*self.0.as_ptr()
-    }
-
-    #[inline]
-    unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw {
-        &mut *self.0.as_ptr()
-    }
+pub struct Storage<'ctx>(Rc<DevMem<'ctx>>);

+impl<'ctx> Storage<'ctx> {
     #[inline]
-    fn access(&self) -> Self::Access<'_> {
-        MemRef(self.0.borrow())
+    pub fn new(size: usize, stream: &Stream<'ctx>) -> Self {
+        Self(Rc::new(stream.malloc::<u8>(size)))
     }

     #[inline]
-    fn access_mut(&mut self) -> Self::AccessMut<'_> {
-        MemRefMut(self.0.borrow_mut())
+    pub unsafe fn borrow(&self) -> Self {
+        Self(self.0.clone())
     }
 }

-impl<'ctx> Deref for MemRef<'_, 'ctx> {
-    type Target = DevSlice;
+impl<'ctx> From<DevMem<'ctx>> for Storage<'ctx> {
     #[inline]
-    fn deref(&self) -> &Self::Target {
-        &self.0
+    fn from(mem: DevMem<'ctx>) -> Self {
+        Self(Rc::new(mem))
     }
 }

-impl<'ctx> Deref for MemRefMut<'_, 'ctx> {
+impl<'ctx> Deref for Storage<'ctx> {
     type Target = DevSlice;

     #[inline]
     fn deref(&self) -> &Self::Target {
         &self.0
     }
 }

-impl<'ctx> DerefMut for MemRefMut<'_, 'ctx> {
+impl<'ctx> DerefMut for Storage<'ctx> {
     #[inline]
     fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.0
+        unsafe { self.0.get_mut() }
     }
 }

-impl<'ctx> Storage<'ctx> {
-    #[inline]
-    pub fn new(len: usize, stream: &Stream<'ctx>) -> Self {
-        Self(Rc::new(RefCell::new(stream.malloc::<u8>(len))))
-    }
-}
-
-impl<'ctx> From<DevMem<'ctx>> for Storage<'ctx> {
+impl<'ctx> Splitable for Storage<'ctx> {
     #[inline]
-    fn from(value: DevMem<'ctx>) -> Self {
-        Self(Rc::new(RefCell::new(value)))
+    fn split(&self) -> Self {
+        Self(self.0.clone())
     }
 }
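
`Storage` drops the `RefCell` and the `PhysicalCell` machinery: it is now a plain `Rc<DevMem>` that derefs to the device slice, and aliasing is made explicit through `unsafe fn borrow` and `Splitable::split`, both of which just clone the `Rc`. A hedged usage sketch; `stream` is an illustrative binding for an existing `cuda::Stream`:

// Assumed: stream: cuda::Stream<'ctx>, Splitable in scope.
let buf = Storage::new(4096, &stream);   // one device allocation behind an Rc
let alias = unsafe { buf.borrow() };     // second handle to the same DevMem
let view = buf.split();                  // tensor splitting clones the Rc as well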
