Commit

temp save
kilinchange committed May 6, 2024
1 parent 683e3be commit 12baabb
Showing 5 changed files with 309 additions and 223 deletions.
Cargo.lock (4 changes: 0 additions & 4 deletions)

Generated file; diff not rendered by default.

Cargo.toml (8 changes: 4 additions & 4 deletions)
@@ -21,7 +21,7 @@ serde = "1.0"
 log = "0.4"
 tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] }

-cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "96833d2" }
-cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "96833d2" }
-nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "96833d2" }
-search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "96833d2" }
+cuda = { path = "../cuda-driver/cuda" }
+cublas = { path = "../cuda-driver/cublas" }
+nccl = { path = "../cuda-driver/nccl" }
+search-cuda-tools = { path = "../cuda-driver/search-cuda-tools" }
nvidia/common/src/lib.rs (38 changes: 36 additions & 2 deletions)
@@ -14,17 +14,18 @@ mod swiglu;
 mod paged_attention;

 pub use common::utok;
-pub use tensor::{slice, udim, DataType, LocalSplitable, Tensor};
+pub use tensor::{slice, udim, reslice, DataType, LocalSplitable, Tensor};

 use cublas::{Cublas, CublasSpore};
 use cuda::{
-    memcpy_d2h, ContextGuard, ContextResource, ContextSpore, CudaDataType::f16, DevByte,
+    memcpy_d2h, ContextGuard, ContextResource, ContextSpore, CudaDataType::f16, CudaDataType::u16, DevByte,
     ModuleSpore, Ptx, Stream,
 };
 use fused_softmax::FusedSoftmax;
 use reform::Reform;
 use rms_norm::RmsNormalization;
 use rotary_embedding::Rope;
+use paged_attention::PagedAttention;
 use std::{
     ops::{Deref, DerefMut},
     sync::Arc,
@@ -40,6 +41,7 @@ pub struct NvidiaKernelsPtx {
     reform: Arc<Reform>,
     softmax: Arc<FusedSoftmax>,
     swiglu: Arc<Swiglu>,
+    paged_attention: Arc<PagedAttention>,
 }

impl NvidiaKernelsPtx {
@@ -56,6 +58,7 @@ impl NvidiaKernelsPtx {
                 block_size,
             )),
             swiglu: Arc::new(Swiglu::new(f16, block_size)),
+            paged_attention: Arc::new(PagedAttention::new(host.num_attention_heads() as u32, 1, 128, 32, u16, u16)),
         }
     }
 }
@@ -85,6 +88,7 @@ pub struct NvidiaKernels {
     reform: ModuleWapper<Reform>,
     softmax: ModuleWapper<FusedSoftmax>,
     swiglu: ModuleWapper<Swiglu>,
+    paged_attention: ModuleWapper<PagedAttention>,
 }

impl NvidiaKernelsPtx {
@@ -98,6 +102,7 @@ impl NvidiaKernelsPtx {
             reform: self.reform.clone().load(ctx),
             softmax: self.softmax.clone().load(ctx),
             swiglu: self.swiglu.clone().load(ctx),
+            paged_attention: self.paged_attention.clone().load(ctx),
         }
     }
 }
@@ -111,6 +116,7 @@ impl NvidiaKernels {
             self.reform.module.kill(ctx);
             self.softmax.module.kill(ctx);
             self.swiglu.module.kill(ctx);
+            self.paged_attention.module.kill(ctx);
         }
     }
 }
@@ -130,6 +136,32 @@ impl NvidiaKernels {
}
}

+impl KernelRuntime<'_> {
+    #[inline]
+    pub fn paged_attention<OutT, QT, KT, VT, BlockTablesT, SeqLensT>(
+        &self,
+        out: &mut Tensor<OutT>,
+        query: &Tensor<QT>,
+        key_cache: &Tensor<KT>,
+        value_cache: &Tensor<VT>,
+        num_kv_heads: u32,
+        scale: f32,
+        block_tables: &Tensor<BlockTablesT>,
+        seq_lens: &Tensor<SeqLensT>,
+        max_seq_len: u32,
+    ) where
+        OutT: DerefMut<Target = [DevByte]>,
+        QT: Deref<Target = [DevByte]>,
+        KT: Deref<Target = [DevByte]>,
+        VT: Deref<Target = [DevByte]>,
+        BlockTablesT: Deref<Target = [DevByte]>,
+        SeqLensT: Deref<Target = [DevByte]>,
+    {
+        let ModuleWapper { module, kernel } = &self.kernels.paged_attention;
+        kernel.launch(
+            module, out, query, key_cache, value_cache, num_kv_heads, scale,
+            block_tables, seq_lens, max_seq_len, self.stream,
+        );
+    }
+}
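
For context, a minimal sketch of how a caller might drive the new KernelRuntime::paged_attention entry point. Only the method signature comes from the diff above; the decode_step wrapper, the tensor-shape comments, the head_size and max_seq_len constants, and the reading of PagedAttention::new's positional arguments as (num_heads, num_kv_heads, head_size, block_size, ...) are assumptions modeled on common paged-attention layouts, not something this commit states.

// Hypothetical usage sketch; names, shapes, and constants are assumptions, not from this commit.
fn decode_step(
    kernels: &KernelRuntime<'_>,       // runtime holding the loaded paged_attention module
    out: &mut Tensor<&mut [DevByte]>,  // assumed [num_seqs, num_heads, head_size]
    query: &Tensor<&[DevByte]>,        // assumed [num_seqs, num_heads, head_size]
    key_cache: &Tensor<&[DevByte]>,    // assumed paged K-cache blocks resident on the device
    value_cache: &Tensor<&[DevByte]>,  // assumed paged V-cache blocks resident on the device
    block_tables: &Tensor<&[DevByte]>, // assumed [num_seqs, max_blocks_per_seq] block indices
    seq_lens: &Tensor<&[DevByte]>,     // assumed [num_seqs] current sequence lengths
) {
    let num_kv_heads: u32 = 1;   // assumed to match the second argument of PagedAttention::new above
    let head_size: u32 = 128;    // assumed to match the third argument of PagedAttention::new above
    let scale = 1.0 / (head_size as f32).sqrt(); // standard 1/sqrt(d) attention scaling
    let max_seq_len: u32 = 1024; // assumed batch-wide upper bound

    kernels.paged_attention(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        max_seq_len,
    );
}

In a real decode path the shapes, the scale, and the sequence-length bound would come from the model configuration and the current batch rather than constants.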

impl Kernels for KernelRuntime<'_> {
type Storage = [DevByte];

@@ -210,6 +242,8 @@ impl Kernels for KernelRuntime<'_> {
         let ModuleWapper { module, kernel } = &self.kernels.swiglu;
         kernel.launch(module, gate, up, self.stream);
     }
+
+
 }

#[allow(unused)]
(The remaining 2 changed files had not loaded and are not shown.)
