refactor(transformer-nv): rework the interface on top of causal-lm; support launching via chat
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 25, 2024
1 parent 07deb88 commit 0abf47a
Showing 13 changed files with 539 additions and 390 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

14 changes: 11 additions & 3 deletions causal-lm/src/lib.rs
@@ -12,12 +12,20 @@ use common::{upos, utok};
use std::path::Path;
use tensor::{udim, Tensor};

/// Model.
pub trait Model: Sized {
/// Metadata used to load the model.
type Meta;
/// Errors that may occur while loading the model.
type Error;
/// Loads the model from the file system.
fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error>;
}

/// Causal language model.
pub trait CausalLM {
pub trait CausalLM: Model {
/// Type that stores intermediate results.
type Storage;
/// Loads the model from the file system.
fn load(model_dir: impl AsRef<Path>) -> Self;
/// The end-of-sentence token defined by the model.
fn eos_token(&self) -> utok;
/// Creates a new cache (`num_layers x 2 x num_kv_head x max_seq_len x head_dim`).
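To make the new split concrete, the sketch below shows how a backend could implement the two traits: loading concerns live in `Model`, inference concerns in `CausalLM`. `DummyLM`, `DummyMeta`, and the `Utok` alias are hypothetical stand-ins for the crate's real `common::utok` and model types, not code from this commit.

```rust
use std::path::Path;

/// Hypothetical alias standing in for `common::utok`.
type Utok = u32;

/// Local stand-ins for the traits shown in the hunk above (other methods omitted).
trait Model: Sized {
    type Meta;
    type Error;
    fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error>;
}

trait CausalLM: Model {
    type Storage;
    fn eos_token(&self) -> Utok;
}

/// Toy backend used only to illustrate the split.
struct DummyLM {
    eos: Utok,
}

struct DummyMeta {
    eos: Utok,
}

impl Model for DummyLM {
    type Meta = DummyMeta;
    type Error = std::io::Error;

    fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error> {
        // A real backend would read weights from `model_dir`;
        // here we only check that the directory exists.
        std::fs::metadata(model_dir.as_ref())?;
        Ok(Self { eos: meta.eos })
    }
}

impl CausalLM for DummyLM {
    type Storage = Vec<u8>;

    fn eos_token(&self) -> Utok {
        self.eos
    }
}

fn main() -> Result<(), std::io::Error> {
    let lm = DummyLM::load(".", DummyMeta { eos: 2 })?;
    println!("eos = {}", lm.eos_token());
    Ok(())
}
```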
2 changes: 1 addition & 1 deletion nvidia/common/src/lib.rs
@@ -14,7 +14,7 @@ mod swiglu;

pub use common::{
safe_tensors::{SafeTensors, SafeTensorsError},
test_model, utok,
test_model, upos, utok,
};
pub use tensor::{slice, split, udim, DataType, LocalSplitable, Tensor};

24 changes: 13 additions & 11 deletions nvidia/common/src/rms_norm.cuh
@@ -4,13 +4,14 @@
// assert BLOCK_SIZE >= blockDim.x
template<unsigned int BLOCK_SIZE, class Tdata>
static __device__ void padding(
Tdata *__restrict__ y_,
Tdata *__restrict__ o_,
unsigned int const stride_o,
Tdata const *__restrict__ x_,
unsigned int const stride_x,
Tdata const *__restrict__ w_,
float const epsilon,
unsigned int const leading_dim) {
auto y = y_ + blockIdx.x * leading_dim + threadIdx.x;
auto x = x_[blockIdx.x * leading_dim + threadIdx.x];
float const epsilon) {
auto o = o_ + blockIdx.x * stride_o + threadIdx.x;
auto x = x_[blockIdx.x * stride_x + threadIdx.x];
auto w = w_[threadIdx.x];

using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
@@ -23,19 +24,20 @@ static __device__ void padding(
}
__syncthreads();

*y = rms * x * w;
*o = rms * x * w;
}

template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata>
static __device__ void folding(
Tdata *__restrict__ y,
Tdata *__restrict__ o,
unsigned int const stride_o,
Tdata const *__restrict__ x,
unsigned int const stride_x,
Tdata const *__restrict__ w,
float const epsilon,
unsigned int const leading_dim,
unsigned int const items_size) {
y += blockIdx.x * leading_dim;
x += blockIdx.x * leading_dim;
o += blockIdx.x * stride_o;
x += blockIdx.x * stride_x;

float thread_data[ITEMS_PER_THREAD];
{
@@ -66,7 +68,7 @@ static __device__ void folding(
#pragma unroll
for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (auto j = i + threadIdx.x * ITEMS_PER_THREAD; j < items_size) {
y[j] = Tdata(float(rms) * float(thread_data[i]) * float(w[j]));
o[j] = Tdata(float(rms) * float(thread_data[i]) * float(w[j]));
}
}
}
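For orientation, here is a CPU reference of what the padding/folding kernels compute per row. It assumes the collapsed block reduction produces the mean of squares (the standard RMSNorm, with epsilon added before the square root), and the strides are element strides mirroring the new `stride_o`/`stride_x` parameters; this is a sketch, not code from the repository.

```rust
/// CPU reference, assuming the kernels compute
/// o[i] = x[i] * w[i] / sqrt(mean(x^2) + epsilon) for each row.
fn rms_norm_rows(
    o: &mut [f32],
    stride_o: usize,
    x: &[f32],
    stride_x: usize,
    w: &[f32],
    epsilon: f32,
    rows: usize,
) {
    let d = w.len();
    for r in 0..rows {
        let x_row = &x[r * stride_x..r * stride_x + d];
        // Mean of squares over the row, then the reciprocal root.
        let mean_sq = x_row.iter().map(|v| v * v).sum::<f32>() / d as f32;
        let rms = (mean_sq + epsilon).sqrt().recip();
        for i in 0..d {
            o[r * stride_o + i] = rms * x_row[i] * w[i];
        }
    }
}

fn main() {
    let x = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let w = vec![1.0_f32, 1.0, 1.0];
    let mut o = vec![0.0_f32; 6];
    rms_norm_rows(&mut o, 3, &x, 3, &w, 1e-5, 2);
    println!("{o:?}");
}
```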
51 changes: 31 additions & 20 deletions nvidia/common/src/rms_norm.rs
@@ -33,33 +35,35 @@ impl RmsNormalization {
r#"{RMS_NORMALIZATION}
extern "C" __global__ void {padding}(
{ty_arg} *__restrict__ y,
{ty_arg} *__restrict__ o,
unsigned int const stride_o,
{ty_arg} const *__restrict__ x,
unsigned int const stride_x,
{ty_arg} const *__restrict__ w,
float epsilon,
unsigned int const leading_dim
float epsilon
){{
padding<{block_size}>
(y, x, w, epsilon, leading_dim);
(o, stride_o, x, stride_x, w, epsilon);
}}
extern "C" __global__ void {folding}(
{ty_arg} *__restrict__ y,
{ty_arg} *__restrict__ o,
unsigned int const stride_o,
{ty_arg} const *__restrict__ x,
unsigned int const stride_x,
{ty_arg} const *__restrict__ w,
float epsilon,
unsigned int const leading_dim,
unsigned int const items_size
){{
folding<{block_size}, {items_per_thread}>
(y, x, w, epsilon, leading_dim, items_size);
(o, stride_o, x, stride_x, w, epsilon, items_size);
}}
"#
);

let (ptx, log) = Ptx::compile(code);
if !log.is_empty() {
warn!("{log}");
println!("{log}");
}
Self {
ptx: ptx.unwrap(),
@@ -73,7 +75,7 @@ extern "C" __global__ void {folding}(
pub fn launch<T, U, V>(
&self,
module: &ModuleSpore,
y: &mut Tensor<T>,
o: &mut Tensor<T>,
x: &Tensor<U>,
w: &Tensor<V>,
epsilon: f32,
@@ -83,31 +85,40 @@ extern "C" __global__ void {folding}(
U: Deref<Target = [DevByte]>,
V: Deref<Target = [DevByte]>,
{
debug_assert_eq!(x.shape(), y.shape());
let &[row, col] = x.shape() else { panic!() };
debug_assert_eq!(&[col], w.shape());
let &[n, d] = o.shape() else { panic!() };
let dt = o.data_type();

assert_eq!(x.data_type(), dt);
assert_eq!(w.data_type(), dt);
assert_eq!(o.shape(), x.shape());
assert_eq!(w.shape(), &[d]);
assert!(o.contiguous_len() >= 1);
assert!(x.contiguous_len() >= 1);
assert!(w.is_contiguous());

let y_ptr = (y.physical().as_ptr() as isize + y.bytes_offset()) as CUdeviceptr;
let o_ptr = (o.physical().as_ptr() as isize + o.bytes_offset()) as CUdeviceptr;
let x_ptr = (x.physical().as_ptr() as isize + x.bytes_offset()) as CUdeviceptr;
let w_ptr = (w.physical().as_ptr() as isize + w.bytes_offset()) as CUdeviceptr;
let leading_dim = x.strides()[0] as udim;
let items_len = col as udim;
let params: [*const c_void; 6] = [
(&y_ptr) as *const _ as _,
let stride_o = o.strides()[0] as usize;
let stride_x = x.strides()[0] as usize;
let items_len = d as udim;
let params: [*const c_void; 7] = [
(&o_ptr) as *const _ as _,
(&stride_o) as *const _ as _,
(&x_ptr) as *const _ as _,
(&stride_x) as *const _ as _,
(&w_ptr) as *const _ as _,
(&epsilon) as *const _ as _,
(&leading_dim) as *const _ as _,
(&items_len) as *const _ as _,
];
let module = unsafe { module.sprout(stream.ctx()) };
if items_len <= self.block_size {
let kernel = module.get_kernel(&self.padding);
kernel.launch(row, items_len, params.as_ptr(), 0, Some(stream));
kernel.launch(n, items_len, params.as_ptr(), 0, Some(stream));
} else {
let block_size = (items_len + self.items_per_thread - 1) / self.items_per_thread;
let kernel = module.get_kernel(&self.folding);
kernel.launch(row, block_size, params.as_ptr(), 0, Some(stream));
kernel.launch(n, block_size, params.as_ptr(), 0, Some(stream));
}
}
}
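The branch at the end of `launch` picks between the two kernels. The sketch below restates that selection logic in plain Rust (names are hypothetical, derived from the diff, not from the crate): rows map to CUDA blocks, a row that fits in one block uses `padding` with one thread per element, and longer rows use `folding` with `items_per_thread` elements packed into each thread.

```rust
/// Returns (kernel name, grid size, block size) for one launch,
/// mirroring the dispatch in `RmsNormalization::launch` above.
fn pick_launch(n: u32, items_len: u32, block_size: u32, items_per_thread: u32) -> (&'static str, u32, u32) {
    if items_len <= block_size {
        // One thread per element of the row.
        ("padding", n, items_len)
    } else {
        // ceil(items_len / items_per_thread) threads per block.
        ("folding", n, (items_len + items_per_thread - 1) / items_per_thread)
    }
}

fn main() {
    assert_eq!(pick_launch(8, 1024, 1024, 4), ("padding", 8, 1024));
    assert_eq!(pick_launch(8, 4096, 1024, 4), ("folding", 8, 1024));
}
```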
2 changes: 2 additions & 0 deletions nvidia/transformer/Cargo.toml
@@ -7,7 +7,9 @@ authors = ["YdrMaster <[email protected]>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
causal-lm = { path = "../../causal-lm" }
transformer = { path = "../../transformer" }
itertools = "0.12"
common-nv = { path = "../common" }
half.workspace = true

