perf(service): optimize random sampling when no top-k pruning is applied
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 18, 2024
1 parent 1784077 commit 8d69e49
Showing 3 changed files with 37 additions and 23 deletions.
12 changes: 5 additions & 7 deletions transformer-nvidia/src/lib.rs
```diff
@@ -14,7 +14,7 @@ use cuda::{AsRaw, CudaDataType::half, Stream};
 use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbedding, Swiglu};
 use parameters::{LayersParameters, ModelParameters};
 use storage::Storage;
-use tensor::{reslice, slice, udim, DataType, Tensor};
+use tensor::{slice, udim, DataType, Tensor};
 use transformer::SampleArgs;
 
 pub type Request<'a, 'b, Id> = transformer::Request<'a, Id, Storage<'b>>;
@@ -277,7 +277,6 @@ impl<'ctx> Transformer<'ctx> {
         // compute.synchronize();
         // println!("model norm:\n{}", map_tensor(&x));
 
-        let mut logits = unsafe { logits_dev.as_ref().map_physical(|dev| vec![0; dev.len()]) };
         mat_mul(
             &self.cublas,
             &mut logits_dev,
@@ -286,18 +285,17 @@
             &self.model.lm_head,
             1.,
         );
-        compute.synchronize();
-        logits_dev.physical().copy_out(logits.physical_mut());
-        // println!("logits:\n{}", logits);
 
-        let logits: &[f16] = reslice(logits.as_slice());
+        let mut logits = vec![f16::ZERO; logits_dev.size()];
+        compute.synchronize();
+        logits_dev.physical().copy_out(&mut logits);
         requests
             .into_iter()
             .enumerate()
             .map(|(i, r)| {
                 (
                     r.id,
-                    sample.random(&logits[i * voc as usize..][..voc as usize]),
+                    sample.random(&mut logits[i * voc as usize..][..voc as usize]),
                 )
             })
             .collect()
```
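
The new flow copies the device logits once into an owned host `Vec<f16>` and hands each request a mutable window of that buffer, so the sampler can reorder its slice in place instead of reading through a reinterpreted `&[f16]`. Below is a minimal sketch of that flow, assuming the `half` crate; `DeviceBuf` is a stand-in for the real CUDA storage (its `copy_out` mirrors the call in the diff) and `argmax` is only a placeholder for `SampleArgs::random`:

```rust
use half::f16;

// Stand-in for the device-side buffer: pretend this Vec lives on the GPU.
struct DeviceBuf(Vec<f16>);

impl DeviceBuf {
    fn size(&self) -> usize {
        self.0.len()
    }
    fn copy_out(&self, host: &mut [f16]) {
        host.copy_from_slice(&self.0);
    }
}

fn sample_all(logits_dev: &DeviceBuf, voc: usize, n_requests: usize) -> Vec<usize> {
    // One owned host buffer for all requests, sized from the device tensor.
    let mut logits = vec![f16::ZERO; logits_dev.size()];
    logits_dev.copy_out(&mut logits);
    // Each request samples from a mutable, vocabulary-sized window of the
    // buffer, so the sampler is free to reorder it in place.
    (0..n_requests)
        .map(|i| argmax(&mut logits[i * voc..][..voc]))
        .collect()
}

// Placeholder sampler: plain argmax instead of the project's SampleArgs::random.
fn argmax(logits: &mut [f16]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.to_f32().total_cmp(&b.1.to_f32()))
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let dev = DeviceBuf(
        [0.1, 0.9, 0.3, 0.2].iter().map(|&x| f16::from_f32(x)).collect(),
    );
    assert_eq!(sample_all(&dev, 2, 2), vec![1, 0]);
}
```
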
30 changes: 23 additions & 7 deletions transformer/src/sample.rs
```diff
@@ -47,20 +47,36 @@ impl SampleArgs {
                 }
             }
         }
+        impl From<(usize, &f16)> for Probability {
+            #[inline]
+            fn from((i, p): (usize, &f16)) -> Self {
+                Self {
+                    val: p.to_f32(),
+                    tok: i as _,
+                }
+            }
+        }
+
         // top-k & max
-        let (logits, max) = {
+        let logits = if self.top_k < logits.len() {
             let mut buf = BinaryHeap::with_capacity(self.top_k + 1);
-            let mut max = f32::NEG_INFINITY;
-            for (i, p) in logits.iter().enumerate() {
-                let val = p.to_f32();
-                max = max.max(val);
-                buf.push(Probability { val, tok: i as _ });
+            for it in logits.iter().enumerate() {
+                buf.push(Probability::from(it));
                 if buf.len() > self.top_k {
                     buf.pop();
                 }
             }
-            (buf.into_vec(), max)
+            buf.into_vec()
+        } else {
+            let mut buf = logits
+                .iter()
+                .enumerate()
+                .map(Probability::from)
+                .collect::<Vec<_>>();
+            buf.sort_unstable();
+            buf
         };
+        let max = logits[0].val;
         // temperature & sum
         let (logits, sum) = {
             let mut logits = logits;
```
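
The reworked sampler now has two explicit paths: when top-k really prunes, a bounded binary heap keeps only the k most probable candidates in O(n log k); when it does not (the case this commit targets), every candidate survives anyway, so collecting them once and sorting is enough, and the separate running max disappears because the most probable entry ends up at index 0. A self-contained sketch of that split follows; `Prob` is a hypothetical stand-in for the crate's `Probability` with a deliberately reversed ordering, and, unlike the diff's `into_vec`, the heap path uses `into_sorted_vec` so index 0 is the best token on both paths:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Hypothetical stand-in for the crate's `Probability`. The ordering is
// reversed on purpose: a larger probability compares as "less", so std's
// max-heap behaves as a min-heap on probability and an ascending sort puts
// the most probable token first.
#[derive(PartialEq)]
struct Prob {
    val: f32,
    tok: u32,
}
impl Eq for Prob {}
impl Ord for Prob {
    fn cmp(&self, other: &Self) -> Ordering {
        other.val.total_cmp(&self.val)
    }
}
impl PartialOrd for Prob {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Keep the `top_k` most probable tokens, most probable first; fall back to
/// a single full sort when `top_k` does not actually prune anything.
fn prune_top_k(logits: &[f32], top_k: usize) -> Vec<Prob> {
    if top_k < logits.len() {
        // Bounded heap: O(n log k) instead of sorting all n candidates.
        let mut buf = BinaryHeap::with_capacity(top_k + 1);
        for (i, &val) in logits.iter().enumerate() {
            buf.push(Prob { val, tok: i as u32 });
            if buf.len() > top_k {
                buf.pop(); // drop the least probable of the k + 1 candidates
            }
        }
        buf.into_sorted_vec()
    } else {
        // No pruning requested: collect everything and sort once.
        let mut buf: Vec<Prob> = logits
            .iter()
            .enumerate()
            .map(|(i, &val)| Prob { val, tok: i as u32 })
            .collect();
        buf.sort_unstable();
        buf
    }
}

fn main() {
    let logits = [0.10_f32, 0.70, 0.05, 0.15];
    let pruned = prune_top_k(&logits, 2);
    assert_eq!(pruned.len(), 2);
    assert_eq!(pruned[0].tok, 1); // the max now sits at index 0
    let all = prune_top_k(&logits, usize::MAX);
    assert_eq!(all[0].val, 0.70);
}
```
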
18 changes: 9 additions & 9 deletions xtask/src/main.rs
```diff
@@ -42,14 +42,14 @@ struct InferenceArgs {
     #[clap(short, long)]
     model: String,
     /// Temperature for random sampling.
-    #[clap(long, default_value = "0.0")]
-    temperature: f32,
+    #[clap(long)]
+    temperature: Option<f32>,
     /// Top-k for random sampling.
-    #[clap(long, default_value = "usize::MAX")]
-    top_k: usize,
+    #[clap(long)]
+    top_k: Option<usize>,
     /// Top-p for random sampling.
-    #[clap(long, default_value = "1.0")]
-    top_p: f32,
+    #[clap(long)]
+    top_p: Option<f32>,
     /// Log level, may be "off", "trace", "debug", "info" or "error".
     #[clap(long)]
     log: Option<String>,
@@ -88,9 +88,9 @@ impl From<InferenceArgs> for Service {
         Service::load_model(
             model,
             SampleArgs {
-                temperature,
-                top_k,
-                top_p,
+                temperature: temperature.unwrap_or(0.),
+                top_k: top_k.unwrap_or(usize::MAX),
+                top_p: top_p.unwrap_or(1.),
             },
             if nvidia {
                 Device::NvidiaGpu(0)
```
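
Switching the sampling flags to `Option` moves the defaults out of clap and into plain Rust, so a value like `usize::MAX` no longer has to round-trip through a string default. A small sketch of the pattern, assuming a clap version with the derive feature that accepts `#[clap(...)]` attributes; `Args` and the local `SampleArgs` mirror are reduced stand-ins for illustration:

```rust
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Temperature for random sampling.
    #[clap(long)]
    temperature: Option<f32>,
    /// Top-k for random sampling.
    #[clap(long)]
    top_k: Option<usize>,
    /// Top-p for random sampling.
    #[clap(long)]
    top_p: Option<f32>,
}

// Local mirror of the crate's SampleArgs, redeclared only to keep the
// sketch self-contained.
struct SampleArgs {
    temperature: f32,
    top_k: usize,
    top_p: f32,
}

fn main() {
    let args = Args::parse();
    // Defaults are applied here rather than by clap's string parsing, so
    // usize::MAX needs no textual representation.
    let sample = SampleArgs {
        temperature: args.temperature.unwrap_or(0.),
        top_k: args.top_k.unwrap_or(usize::MAX),
        top_p: args.top_p.unwrap_or(1.),
    };
    println!(
        "temperature={} top_k={} top_p={}",
        sample.temperature, sample.top_k, sample.top_p
    );
}
```
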
