Skip to content

Commit

Permalink
refactor(transformer-cpu): 基于 causal-lm 接口重写 transformer 推理
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 23, 2024
1 parent 5435e08 commit ce5569f
Show file tree
Hide file tree
Showing 6 changed files with 450 additions and 11 deletions.
16 changes: 10 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions causal-lm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ authors = ["YdrMaster <[email protected]>"]
[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
rand = "0.8"
half.workspace = true
8 changes: 5 additions & 3 deletions causal-lm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//! 提供因果语言模型的特性定义。
// #![deny(warnings, missing_docs)]
#![deny(warnings, missing_docs)]

mod query_context;
mod sample;

pub use query_context::QueryContext;
pub use sample::SampleArgs;

use common::{upos, utok};
use tensor::{udim, Tensor};
Expand Down Expand Up @@ -36,7 +38,7 @@ pub trait CausalLM {
hidden_state: Tensor<Self::Storage>,
) -> Tensor<Self::Storage>;
/// 对 logits 进行采样。
fn sample(&self, logits: Tensor<Self::Storage>) -> Vec<utok>;
fn sample(&self, logits: Tensor<Self::Storage>, args: SampleArgs) -> Vec<utok>;
}

/// 解码的要求。
Expand All @@ -50,7 +52,7 @@ pub struct DecodingMeta {
/// 生成位置张量。
#[inline]
pub fn pos<'a, S: 'a>(
queries: impl IntoIterator<Item = QueryContext<'a, S>>,
queries: impl IntoIterator<Item = &'a QueryContext<'a, S>>,
nt_hint: udim,
) -> Tensor<Vec<upos>> {
let mut ans = Vec::with_capacity(nt_hint as usize);
Expand Down
157 changes: 157 additions & 0 deletions causal-lm/src/sample.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#![allow(missing_docs)]

use common::utok;
use std::{cmp::Ordering, collections::BinaryHeap, fmt::Debug};

#[derive(Clone, PartialEq, Debug)]
pub struct SampleArgs {
pub temperature: f32,
pub top_k: usize,
pub top_p: f32,
}

impl SampleArgs {
#[inline]
fn is_argmax(&self) -> bool {
self.temperature <= 0. || self.top_k < 2 || self.top_p <= 0.
}

pub fn random<T>(&self, logits: &[T]) -> utok
where
T: BetweenF32 + PartialOrd,
{
if self.is_argmax() {
return logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as _;
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct Probability {
val: f32,
tok: utok,
}
impl Eq for Probability {}
impl PartialOrd for Probability {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Probability {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
match self.val.partial_cmp(&other.val).unwrap() {
Ordering::Equal => self.tok.cmp(&other.tok),
ord => ord.reverse(),
}
}
}
impl<T: BetweenF32> From<(usize, &T)> for Probability {
#[inline]
fn from((i, p): (usize, &T)) -> Self {
Self {
val: p.get(),
tok: i as _,
}
}
}

// top-k & max
let logits = if self.top_k < logits.len() {
let mut buf = BinaryHeap::with_capacity(self.top_k + 1);
for it in logits.iter().enumerate() {
buf.push(Probability::from(it));
if buf.len() > self.top_k {
buf.pop();
}
}
buf.into_vec()
} else {
let mut buf = logits
.iter()
.enumerate()
.map(Probability::from)
.collect::<Vec<_>>();
buf.sort_unstable();
buf
};
let max = logits[0].val;
// temperature & sum
let (logits, sum) = {
let mut logits = logits;
let mut sum = 0.;
for pi in logits.iter_mut() {
pi.val = ((pi.val - max) / self.temperature).exp();
sum += pi.val;
}
(logits, sum)
};
// top p
let logits = if self.top_p < 1. {
let i = logits
.iter()
.scan(self.top_p * sum, |top_p, pi| {
if *top_p > 0. {
*top_p -= pi.val;
Some(())
} else {
None
}
})
.count();
&logits[..i]
} else {
&logits[..]
};
// random
let mut rand = rand::random::<f32>() * sum;
logits
.iter()
.find(|pi| {
rand -= pi.val;
rand <= 0.
})
.unwrap_or(logits.last().unwrap())
.tok
}
}

pub trait BetweenF32 {
fn zero() -> Self;
fn cast(f: f32) -> Self;
fn get(&self) -> f32;
}

impl BetweenF32 for f32 {
#[inline]
fn zero() -> Self {
0.
}
#[inline]
fn cast(f: f32) -> Self {
f
}
#[inline]
fn get(&self) -> f32 {
*self
}
}

impl BetweenF32 for half::f16 {
#[inline]
fn zero() -> Self {
Self::ZERO
}
#[inline]
fn cast(f: f32) -> Self {
Self::from_f32(f)
}
#[inline]
fn get(&self) -> f32 {
Self::to_f32(*self)
}
}
2 changes: 2 additions & 0 deletions transformer-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ authors = ["YdrMaster <[email protected]>"]
[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
causal-lm = { path = "../causal-lm" }
transformer = { path = "../transformer" }
itertools = "0.12"
gemm = "0.17"
intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }

Expand Down
Loading

0 comments on commit ce5569f

Please sign in to comment.