Skip to content

Commit

Permalink
refactor(transformer): 添加随机采样的接口
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 18, 2024
1 parent e285972 commit d066026
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 73 deletions.
40 changes: 23 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ resolver = "2"
[workspace.dependencies]
find_cuda_helper = "0.2"
half = "2.4"
rayon = "1.9"
serde_json = "1.0"
serde = "1.0"
log = "0.4"
21 changes: 7 additions & 14 deletions service/src/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::{argmax, Command};
use crate::Command;
use common::utok;
use half::f16;
use std::{collections::HashMap, path::Path, time::Instant};
use tensor::reslice;
use transformer_cpu::{LayerCache, Memory, Request, Transformer};
use transformer_cpu::{LayerCache, Memory, Request, SampleArgs, Transformer};

pub struct CpuTask {
transformer: Transformer,
Expand Down Expand Up @@ -43,22 +41,17 @@ impl CpuTask {
let eos = self.transformer.eos_token_id();

let time = Instant::now();
let mut logits = self
let mut token = self
.transformer
.decode(vec![ctx.request(&prompt, max_seq_len)])
.decode(vec![ctx.request(&prompt, max_seq_len)], SampleArgs::Top)[0]
.1;
info!("prefill transformer ... {:?}", time.elapsed());

loop {
let token = argmax(reslice::<u8, f16>(logits.as_slice()));
if token == eos {
break;
}
while token != eos {
responsing.send(token).unwrap();

logits = self
token = self
.transformer
.decode(vec![ctx.request(&[token], max_seq_len)])
.decode(vec![ctx.request(&[token], max_seq_len)], SampleArgs::Top)[0]
.1;
}
}
Expand Down
9 changes: 0 additions & 9 deletions service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,3 @@ impl<Cache> SessionContext<Cache> {
}
}
}

/// Returns the position of the greatest element of `logits`, cast to `utok`.
///
/// When several elements compare equal to the maximum, the position of the
/// last one is returned.
///
/// # Panics
///
/// Panics if `logits` is empty, or if two compared elements have no ordering
/// (for example a floating-point NaN).
fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
    // An empty slice has no maximum.
    let (head, tail) = logits.split_first().unwrap();

    let mut best_index = 0usize;
    let mut best_value = head;
    for (offset, candidate) in tail.iter().enumerate() {
        // The current best survives only when it is strictly greater, so a
        // later element wins every tie.
        if !best_value.partial_cmp(candidate).unwrap().is_gt() {
            best_value = candidate;
            best_index = offset + 1;
        }
    }
    best_index as _
}
37 changes: 17 additions & 20 deletions service/src/nvidia.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use crate::{argmax, Command};
use crate::Command;
use common::utok;
use half::f16;
use std::{
collections::HashMap, fs::File, io::Read, path::Path, sync::mpsc::Receiver, time::Instant,
};
use tensor::reslice;
use transformer_cpu::{Llama2, Memory};
use transformer_cpu::{Llama2, Memory, SampleArgs};
use transformer_nvidia::{
cuda::{ContextGuard, Stream},
LayerCache, Request, Transformer,
Expand Down Expand Up @@ -49,25 +47,24 @@ pub fn task(model_dir: impl AsRef<Path>, receiver: Receiver<Command>, ctx: &Cont
.or_insert_with_key(|&id| SessionContext::new(&transformer, id, &transfer));

let time = Instant::now();
let mut logits = transformer
.decode(vec![ctx.request(&prompt, max_seq_len)], &compute, &transfer)
.1;
let mut token = transformer.decode(
vec![ctx.request(&prompt, max_seq_len)],
SampleArgs::Top,
&compute,
&transfer,
)[0]
.1;
info!("prefill transformer ... {:?}", time.elapsed());

loop {
let token = argmax(reslice::<u8, f16>(logits.as_slice()));
if token == eos {
break;
}
while token != eos {
responsing.send(token).unwrap();

logits = transformer
.decode(
vec![ctx.request(&[token], max_seq_len)],
&compute,
&transfer,
)
.1;
token = transformer.decode(
vec![ctx.request(&[token], max_seq_len)],
SampleArgs::Top,
&compute,
&transfer,
)[0]
.1;
}
}
Command::Drop { id } => {
Expand Down
4 changes: 2 additions & 2 deletions tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ authors = ["YdrMaster <[email protected]>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
half.workspace = true
smallvec = "1.13"
nalgebra = "0.32"
rayon = "1.9"
half.workspace = true
rayon.workspace = true
serde.workspace = true
4 changes: 2 additions & 2 deletions transformer-cpu/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pub(super) use swiglu::swiglu;

// Expands to an index expression over `$blob` that selects the `$width`
// elements beginning at element offset `$line * $width`.
// NOTE(review): this text is a rendered diff; the two body lines below appear
// to be the before/after versions of one expression (the second additionally
// casts `$width` to `usize`). Only one of them is expected in the real file —
// confirm against the repository.
macro_rules! slice {
($blob:expr; $width:expr; [$line:expr]) => {
$blob[$line as usize * $width..][..$width]
$blob[$line as usize * $width as usize..][..$width as usize]
};
}

use slice;
pub(super) use slice;
38 changes: 35 additions & 3 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ mod kernel;
mod storage;

use common::utok;
use gemm::f16;
use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu};
use storage::Storage;
use tensor::{reslice, slice, udim, DataType, Tensor};

pub type Request<'a, Id> = transformer::Request<'a, Id, Storage>;
pub type LayerCache = transformer::LayerCache<Storage>;
pub use transformer::{save, Llama2, Memory};
pub use transformer::{save, Llama2, Memory, SampleArgs};

pub struct Transformer(Box<dyn Llama2>);

Expand All @@ -33,7 +34,11 @@ impl Transformer {
self.0.eos_token_id()
}

pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> (Vec<Id>, Tensor<Vec<u8>>) {
pub fn decode<Id>(
&mut self,
mut requests: Vec<Request<Id>>,
sample: SampleArgs,
) -> Vec<(Id, utok)> {
requests.sort_unstable_by_key(|t| t.tokens.len());

// println!("tokens:");
Expand Down Expand Up @@ -226,7 +231,25 @@ impl Transformer {
mat_mul(&mut logits, 0., &x, &lm_head, 1.);
// println!("logits:\n{}", logits.access());

(requests.into_iter().map(|r| r.id).collect(), logits)
let logits: &[f16] = reslice(logits.as_slice());
requests
.into_iter()
.enumerate()
.map(|(i, r)| {
let logits = &kernel::slice!(logits; voc; [i]);
(
r.id,
match sample {
SampleArgs::Top => argmax(logits),
SampleArgs::Random {
temperature: _,
top_k: _,
top_p: _,
} => todo!(),
},
)
})
.collect()
}
}

Expand Down Expand Up @@ -257,3 +280,12 @@ fn test_build() {
let t1 = Instant::now();
println!("build transformer {:?}", t1 - t0);
}

/// Finds the index of the maximum value in `logits` and returns it as a `utok`.
///
/// Equal maxima resolve to the one that occurs last in the slice.
///
/// # Panics
///
/// Panics when `logits` is empty, and when a pair of compared elements cannot
/// be ordered (e.g. a NaN is involved).
fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
    let winner = logits.iter().enumerate().reduce(|best, next| {
        // Keep the running best only if it is strictly greater than the
        // newcomer; otherwise the newcomer takes over.
        if best.1.partial_cmp(next.1).unwrap().is_gt() {
            best
        } else {
            next
        }
    });
    // `reduce` yields `None` only for an empty slice.
    let (index, _) = winner.unwrap();
    index as _
}
Loading

0 comments on commit d066026

Please sign in to comment.