diff --git a/transformer-nvidia/src/lib.rs b/transformer-nvidia/src/lib.rs index ca10b669..22ea64c8 100644 --- a/transformer-nvidia/src/lib.rs +++ b/transformer-nvidia/src/lib.rs @@ -14,7 +14,7 @@ use cuda::{AsRaw, CudaDataType::half, Stream}; use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbedding, Swiglu}; use parameters::{LayersParameters, ModelParameters}; use storage::Storage; -use tensor::{reslice, slice, udim, DataType, Tensor}; +use tensor::{slice, udim, DataType, Tensor}; use transformer::SampleArgs; pub type Request<'a, 'b, Id> = transformer::Request<'a, Id, Storage<'b>>; @@ -277,7 +277,6 @@ impl<'ctx> Transformer<'ctx> { // compute.synchronize(); // println!("model norm:\n{}", map_tensor(&x)); - let mut logits = unsafe { logits_dev.as_ref().map_physical(|dev| vec![0; dev.len()]) }; mat_mul( &self.cublas, &mut logits_dev, @@ -286,18 +285,17 @@ impl<'ctx> Transformer<'ctx> { &self.model.lm_head, 1., ); - compute.synchronize(); - logits_dev.physical().copy_out(logits.physical_mut()); - // println!("logits:\n{}", logits); - let logits: &[f16] = reslice(logits.as_slice()); + let mut logits = vec![f16::ZERO; logits_dev.size()]; + compute.synchronize(); + logits_dev.physical().copy_out(&mut logits); requests .into_iter() .enumerate() .map(|(i, r)| { ( r.id, - sample.random(&logits[i * voc as usize..][..voc as usize]), + sample.random(&mut logits[i * voc as usize..][..voc as usize]), ) }) .collect() diff --git a/transformer/src/sample.rs b/transformer/src/sample.rs index fb386839..6f7507d1 100644 --- a/transformer/src/sample.rs +++ b/transformer/src/sample.rs @@ -47,20 +47,36 @@ impl SampleArgs { } } } + impl From<(usize, &f16)> for Probability { + #[inline] + fn from((i, p): (usize, &f16)) -> Self { + Self { + val: p.to_f32(), + tok: i as _, + } + } + } + // top-k & max - let (logits, max) = { + let logits = if self.top_k < logits.len() { let mut buf = BinaryHeap::with_capacity(self.top_k + 1); - let mut max = f32::NEG_INFINITY; - 
for (i, p) in logits.iter().enumerate() { - let val = p.to_f32(); - max = max.max(val); - buf.push(Probability { val, tok: i as _ }); + for it in logits.iter().enumerate() { + buf.push(Probability::from(it)); if buf.len() > self.top_k { buf.pop(); } } - (buf.into_vec(), max) + buf.into_vec() + } else { + let mut buf = logits + .iter() + .enumerate() + .map(Probability::from) + .collect::<Vec<_>>(); + buf.sort_unstable(); + buf }; + let max = logits[0].val; // temperature & sum let (logits, sum) = { let mut logits = logits; diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 88632a0e..deb95649 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -42,14 +42,14 @@ struct InferenceArgs { #[clap(short, long)] model: String, /// Temperature for random sampling. - #[clap(long, default_value = "0.0")] - temperature: f32, + #[clap(long)] + temperature: Option<f32>, /// Top-k for random sampling. - #[clap(long, default_value = "usize::MAX")] - top_k: usize, + #[clap(long)] + top_k: Option<usize>, /// Top-p for random sampling. - #[clap(long, default_value = "1.0")] - top_p: f32, + #[clap(long)] + top_p: Option<f32>, /// Log level, may be "off", "trace", "debug", "info" or "error". #[clap(long)] log: Option<String>, @@ -88,9 +88,9 @@ impl From<InferenceArgs> for Service { Service::load_model( model, SampleArgs { - temperature, - top_k, - top_p, + temperature: temperature.unwrap_or(0.), + top_k: top_k.unwrap_or(usize::MAX), + top_p: top_p.unwrap_or(1.), }, if nvidia { Device::NvidiaGpu(0)