feat(transformer): extract the platform-independent parts of the transformer into a crate
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 4, 2024
1 parent 1822036 commit 7511dbe
Showing 13 changed files with 113 additions and 102 deletions.
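The pattern behind the diffs below: `LayerCache`, `Request`, and `Prompt` leave `transformer-cpu` for a new platform-independent `transformer` crate, with `LayerCache` made generic over its storage type and tensor allocation injected as a closure. A minimal, self-contained sketch of that shape (the types here are simplified stand-ins, not the crates' real APIs):

```rust
/// Simplified stand-in for `tensor::Tensor<S>`: a shape plus a storage blob.
pub struct Tensor<S> {
    pub shape: Vec<usize>,
    pub data: S,
}

/// KV cache for one layer, generic over where the memory lives.
pub struct LayerCache<S> {
    k: Tensor<S>,
    v: Tensor<S>,
}

impl<S> LayerCache<S> {
    /// Allocate caches for every layer through an injected tensor factory,
    /// so this code never names a concrete backend.
    pub fn new_layers(
        layers: usize,
        shape: &[usize],
        tensor: impl Fn(&[usize]) -> Tensor<S>,
    ) -> Vec<Self> {
        (0..layers)
            .map(|_| Self { k: tensor(shape), v: tensor(shape) })
            .collect()
    }
}

fn main() {
    // CPU backend: the factory allocates plain host memory (2 bytes per f16).
    let caches = LayerCache::new_layers(32, &[8, 2048, 128], |shape| Tensor {
        shape: shape.to_vec(),
        data: vec![0u8; shape.iter().product::<usize>() * 2],
    });
    assert_eq!(caches.len(), 32);
    // A CUDA backend would instead capture a stream in the closure and
    // return device blobs; `LayerCache` itself is unchanged.
}
```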
23 changes: 13 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -4,6 +4,7 @@ members = [
"tensor",
"model-parameters",
"tokenizer",
"transformer",
"transformer-cpu",
"transformer-nvidia",
"xtask",
1 change: 1 addition & 0 deletions transformer-cpu/Cargo.toml
@@ -10,6 +10,7 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../common" }
tensor = { path = "../tensor" }
model-parameters = { path = "../model-parameters" }
transformer = { path = "../transformer" }
gemm = "0.17"

[dev-dependencies]
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/gather.rs
@@ -3,7 +3,7 @@ use crate::Request;
use std::ops::{Deref, DerefMut};
use tensor::{udim, Tensor};

pub fn gather<T, U>(mut x: Tensor<T>, table: &Tensor<U>, requests: &[Request])
pub fn gather<T, U, X>(mut x: Tensor<T>, table: &Tensor<U>, requests: &[Request<X>])
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
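The only change `gather` needs is a new type parameter: since `Request` is now generic over its cache storage, the kernel threads an unconstrained `X` through without ever touching it. A toy illustration of that pass-through (simplified types, not the real kernel):

```rust
// Toy pass-through generic: the function reads only the tokens, so the
// request's storage parameter `X` needs no bounds at all.
struct Request<X> {
    tokens: Vec<u32>,
    cache: Vec<X>, // storage type is irrelevant to gather
}

fn gather<X>(out: &mut Vec<u32>, requests: &[Request<X>]) {
    for req in requests {
        out.extend_from_slice(&req.tokens);
    }
}

fn main() {
    let reqs = vec![Request { tokens: vec![1, 2, 3], cache: Vec::<Vec<u8>>::new() }];
    let mut out = Vec::new();
    gather(&mut out, &reqs);
    assert_eq!(out, [1, 2, 3]);
}
```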
43 changes: 4 additions & 39 deletions transformer-cpu/src/lib.rs
@@ -1,55 +1,20 @@
mod cache;
mod kernel;
mod storage;

use common::{upos, utok};
use gemm::f16;
use kernel::{gather, mat_mul, rms_norm, rms_norm_inplace, rotary_embedding, softmax, swiglu};
use model_parameters::{Llama2, Memory};
use storage::Storage;
use tensor::{reslice, reslice_mut, slice, udim, DataType, Tensor};

pub use cache::LayerCache;
pub type LayerCache = transformer::LayerCache<Storage>;
pub use transformer::{Prompt, Request};
pub extern crate model_parameters;

pub struct Transformer {
model: Box<dyn Llama2>,
}

pub struct Request<'a> {
pub prompt: Prompt<'a>,
pub cache: &'a mut [LayerCache],
pub pos: upos,
}

pub enum Prompt<'a> {
Prefill(&'a [utok]),
Decode(utok),
}

impl Request<'_> {
#[inline]
pub const fn tokens(&self) -> &[utok] {
match &self.prompt {
Prompt::Prefill(tokens) => tokens,
Prompt::Decode(token) => std::slice::from_ref(token),
}
}

#[inline]
pub const fn seq_len(&self) -> udim {
match self.prompt {
Prompt::Prefill(tokens) => tokens.len() as _,
Prompt::Decode(_) => 1,
}
}

#[inline]
pub const fn att_len(&self) -> udim {
self.pos + self.seq_len()
}
}

impl Transformer {
#[inline]
pub fn new(model: Box<dyn Llama2>) -> Self {
@@ -63,15 +28,15 @@ impl Transformer {

#[inline]
pub fn new_cache(&self) -> Vec<LayerCache> {
LayerCache::new_layers(&*self.model)
LayerCache::new_layers(&*self.model, tensor)
}

#[inline]
pub fn max_seq_len(&self) -> usize {
self.model.max_position_embeddings()
}

pub fn decode(&mut self, mut requests: Vec<Request>) -> Vec<f16> {
pub fn decode(&mut self, mut requests: Vec<Request<Storage>>) -> Vec<f16> {
use std::cmp::Ordering::*;
requests.sort_unstable_by(|a, b| match a.prompt {
Prompt::Prefill(_) => match b.prompt {
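After the extraction, the body of `transformer-cpu/src/lib.rs` reduces to re-exports: the backend pins the shared generic cache to its own host `Storage` and reuses `Prompt` and `Request` unchanged. A sketch of that facade shape (module and types simplified, not the real crate):

```rust
// Simplified stand-in for the extracted `transformer` crate.
mod transformer {
    pub struct LayerCache<S> {
        pub k: S,
        pub v: S,
    }
    pub enum Prompt<'a> {
        Prefill(&'a [u32]),
        Decode(u32),
    }
}

/// Backend-local storage: plain host memory on the CPU.
pub struct Storage(pub Vec<u8>);

// The backend's public `LayerCache` is just the shared generic, pinned.
pub type LayerCache = transformer::LayerCache<Storage>;
// Shared prompt/request types are re-exported untouched.
pub use transformer::Prompt;

fn main() {
    let cache: LayerCache = transformer::LayerCache {
        k: Storage(vec![0; 16]),
        v: Storage(vec![0; 16]),
    };
    let prompt = Prompt::Prefill(&[1, 2, 3]);
    let n = match prompt {
        Prompt::Prefill(t) => t.len(),
        Prompt::Decode(_) => 1,
    };
    assert_eq!(n, 3);
    assert_eq!(cache.k.0.len(), 16);
}
```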
1 change: 1 addition & 0 deletions transformer-nvidia/Cargo.toml
@@ -10,6 +10,7 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../common" }
tensor = { path = "../tensor" }
model-parameters = { path = "../model-parameters" }
transformer = { path = "../transformer" }
cuda = { git = "https://github.com/YdrMaster/cuda-bench" }
cublas = { git = "https://github.com/YdrMaster/cuda-bench" }
half.workspace = true
32 changes: 0 additions & 32 deletions transformer-nvidia/src/cache.rs

This file was deleted.

15 changes: 7 additions & 8 deletions transformer-nvidia/src/lib.rs
@@ -1,9 +1,8 @@
#![cfg(detected_cuda)]

mod cache;
mod kernel;
mod page_locked_memory;
mod parameters;
mod storage;

use ::half::f16;
use common::{upos, utok};
@@ -15,8 +14,8 @@ use parameters::{LayersParameters, ModelParameters};
use std::ptr::null_mut;
use tensor::{slice, udim, DataType, Tensor};

pub use cache::LayerCache;
pub use storage::PageLockedMemory;
pub type LayerCache<'a> = transformer::LayerCache<LocalDevBlob<'a>>;
pub use page_locked_memory::PageLockedMemory;
pub extern crate cuda;
pub extern crate model_parameters;

@@ -73,13 +72,13 @@ impl<'a> Transformer<'a> {

#[inline]
pub fn new_cache<'b>(&self, stream: &'b Stream) -> Vec<LayerCache<'b>> {
LayerCache::new_layers(self.host, stream)
LayerCache::new_layers(self.host, |dt, shape| tensor(dt, shape, stream))
}

pub fn update<'b>(
&mut self,
tokens: &[utok],
cache: &[LayerCache],
cache: &mut [LayerCache],
pos: upos,
compute: &Stream,
transfer: &'b Stream,
@@ -128,7 +127,7 @@ impl<'a> Transformer<'a> {

cublas!(cublasSetStream_v2(self.cublas, compute.as_raw() as _));
compute.wait_for(&e_alloc);
for (layer, cache) in cache.iter().enumerate() {
for (layer, cache) in cache.iter_mut().enumerate() {
self.layers.load(layer, self.host, transfer);
let params = self.layers.sync(layer, compute);

@@ -248,7 +247,7 @@ impl<'a> Transformer<'a> {
pub fn decode(
&mut self,
token: utok,
cache: &[LayerCache],
cache: &mut [LayerCache],
pos: upos,
compute: &Stream,
transfer: &Stream,
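The `&[LayerCache]` to `&mut [LayerCache]` switch follows from the extracted cache API: `LayerCache::get` returns mutable tensor references, so it takes `&mut self`, and that requirement propagates up through `update` and `decode` to the `xtask` call sites further down. A reduced illustration of the propagation (simplified types):

```rust
// Reduced illustration: `get` needs `&mut self`, so every caller that walks
// the caches must hold a mutable slice.
struct Tensor(Vec<u8>);
struct LayerCache {
    k: Tensor,
    v: Tensor,
}

impl LayerCache {
    // Mutable access to both tensors forces `&mut self`...
    fn get(&mut self) -> (&mut Tensor, &mut Tensor) {
        (&mut self.k, &mut self.v)
    }
}

// ...which propagates to `update`/`decode`-style functions:
fn update(cache: &mut [LayerCache]) {
    for layer in cache.iter_mut() {
        let (k, _v) = layer.get();
        k.0.fill(0); // write new keys in place
    }
}

fn main() {
    let mut cache = vec![LayerCache { k: Tensor(vec![1; 8]), v: Tensor(vec![1; 8]) }];
    update(&mut cache); // call sites now pass `&mut kv_cache`
    assert_eq!(cache[0].k.0, [0; 8]);
}
```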
File renamed without changes.
12 changes: 12 additions & 0 deletions transformer/Cargo.toml
@@ -0,0 +1,12 @@
[package]
name = "transformer"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
model-parameters = { path = "../model-parameters" }
25 changes: 16 additions & 9 deletions transformer-cpu/src/cache.rs → transformer/src/cache.rs
@@ -1,21 +1,27 @@
use crate::{tensor, Storage};
use model_parameters::Llama2;
use tensor::{udim, Tensor};
use model_parameters::Llama2;
use tensor::{udim, DataType, Tensor};

pub struct LayerCache {
/// KV cache for one layer.
pub struct LayerCache<Storage> {
/// Key cache, shape = `num_kv_head x max_seq_len x head_dim`.
k: Tensor<Storage>,
/// Value cache, shape = `num_kv_head x max_seq_len x head_dim`.
v: Tensor<Storage>,
}

impl LayerCache {
pub fn new_layers(model: &dyn Llama2) -> Vec<Self> {
let dt = model.data_type();
impl<Storage> LayerCache<Storage> {
/// Alloc KV Cache for all layers.
pub fn new_layers(
model: &dyn Llama2,
tensor: impl Fn(DataType, &[udim]) -> Tensor<Storage>,
) -> Vec<Self> {
let nkvh = model.num_key_value_heads() as udim;
let hd = (model.hidden_size() / model.num_attention_heads()) as udim;
let max_seq_len = model.max_position_embeddings() as udim;
let shape = &[nkvh, max_seq_len, hd];
let dh = (model.hidden_size() / model.num_attention_heads()) as udim;

let dt = model.data_type();
let shape = &[nkvh, max_seq_len, dh];

(0..model.num_hidden_layers())
.map(|_| Self {
k: tensor(dt, shape),
@@ -24,6 +30,7 @@ impl LayerCache {
.collect()
}

/// Get mutable references to the key and value cache.
#[inline]
pub fn get(&mut self) -> (&mut Tensor<Storage>, &mut Tensor<Storage>) {
(&mut self.k, &mut self.v)
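The injected factory, `impl Fn(DataType, &[udim]) -> Tensor<Storage>`, is the one point where a backend decides where cache memory lives. A hypothetical host-memory factory in that shape (`DataType`, `Tensor`, and the `Udim` alias below are simplified stand-ins for the real `tensor` crate items):

```rust
// Hypothetical host-memory factory in the shape `new_layers` expects:
// (DataType, &[udim]) -> Tensor<Storage>. All types are simplified stand-ins.
type Udim = u32;

#[derive(Clone, Copy)]
enum DataType {
    F16,
    F32,
}

impl DataType {
    fn size(self) -> usize {
        match self {
            DataType::F16 => 2,
            DataType::F32 => 4,
        }
    }
}

struct Tensor<S> {
    dt: DataType,
    shape: Vec<Udim>,
    data: S,
}

fn host_tensor(dt: DataType, shape: &[Udim]) -> Tensor<Vec<u8>> {
    let elems: usize = shape.iter().map(|&d| d as usize).product();
    Tensor { dt, shape: shape.to_vec(), data: vec![0u8; elems * dt.size()] }
}

fn main() {
    // shape = [num_kv_head, max_seq_len, head_dim], as in `new_layers`.
    let t = host_tensor(DataType::F16, &[8, 2048, 128]);
    assert_eq!(t.data.len(), 8 * 2048 * 128 * 2);
}
```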
54 changes: 54 additions & 0 deletions transformer/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//! Common code for transformers.
#![deny(warnings, missing_docs)]

mod cache;

use common::{upos, utok};
use tensor::udim;

pub use cache::LayerCache;

/// A request to decode a sequence.
pub struct Request<'a, Storage> {
/// Prompt of this request.
pub prompt: Prompt<'a>,
/// Context cache of this request.
pub cache: &'a mut [LayerCache<Storage>],
/// Position of `prompt` in context.
pub pos: upos,
}

/// User prompt in transformer inference once.
pub enum Prompt<'a> {
/// Prefill the sequence with tokens.
Prefill(&'a [utok]),
/// Decode the next token.
Decode(utok),
}

impl<S> Request<'_, S> {
/// Tokens in the prompt.
#[inline]
pub const fn tokens(&self) -> &[utok] {
match &self.prompt {
Prompt::Prefill(tokens) => tokens,
Prompt::Decode(token) => std::slice::from_ref(token),
}
}

/// Length of tokens in the prompt.
#[inline]
pub const fn seq_len(&self) -> udim {
match self.prompt {
Prompt::Prefill(tokens) => tokens.len() as _,
Prompt::Decode(_) => 1,
}
}

/// Length of tokens in attention computation.
#[inline]
pub const fn att_len(&self) -> udim {
self.pos + self.seq_len()
}
}
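`seq_len` and `att_len` encode the usual KV-cache arithmetic: a prefill feeds many tokens at once, a decode feeds one, and attention always spans the `pos` cached tokens plus the new ones. A quick standalone check of that arithmetic (`Prompt` and `Request` redefined locally, with `u32` standing in for `utok`/`upos`/`udim`):

```rust
// Standalone check of the prompt-length arithmetic from the new crate.
enum Prompt<'a> {
    Prefill(&'a [u32]),
    Decode(u32),
}

struct Request<'a> {
    prompt: Prompt<'a>,
    pos: u32, // position of `prompt` in the context
}

impl Request<'_> {
    fn seq_len(&self) -> u32 {
        match self.prompt {
            Prompt::Prefill(tokens) => tokens.len() as u32,
            Prompt::Decode(_) => 1,
        }
    }
    fn att_len(&self) -> u32 {
        self.pos + self.seq_len()
    }
}

fn main() {
    // Prefill: 4 prompt tokens enter an empty context.
    let prefill = Request { prompt: Prompt::Prefill(&[1, 2, 3, 4]), pos: 0 };
    assert_eq!((prefill.seq_len(), prefill.att_len()), (4, 4));

    // Decode: one new token attends to itself plus the 4 cached tokens.
    let decode = Request { prompt: Prompt::Decode(5), pos: 4 };
    assert_eq!((decode.seq_len(), decode.att_len()), (1, 5));
}
```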
6 changes: 3 additions & 3 deletions xtask/src/generate.rs
@@ -253,7 +253,7 @@ fn on_nvidia_gpu(
let host = Memory::load_safetensors(config, host, false).unwrap();
let eos = host.eos_token_id();
let mut transformer = Transformer::new(&host, preload_layers, &transfer);
let kv_cache = transformer.new_cache(&compute);
let mut kv_cache = transformer.new_cache(&compute);
info!("build model host: {:?}", time.elapsed());

let step = step.min(host.max_position_embeddings());
@@ -264,7 +264,7 @@
let time = Instant::now();
let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
if !tokens.is_empty() {
transformer.update(tokens, &kv_cache, 0, &compute, &transfer);
transformer.update(tokens, &mut kv_cache, 0, &compute, &transfer);
}
info!("prefill transformer ... {:?}", time.elapsed());

@@ -274,7 +274,7 @@
let mut pos = tokens.len();
let time = Instant::now();
while pos < step {
let logits = transformer.decode(token, &kv_cache, pos as _, &compute, &transfer);
let logits = transformer.decode(token, &mut kv_cache, pos as _, &compute, &transfer);
token = argmax(logits);

print!("{}", tokenizer.decode(token).replace('▁', " "));
