Commit 0ecabb1

todo: initial CPU inference implementation

Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Sep 27, 2024
1 parent 6462640 commit 0ecabb1
Showing 10 changed files with 742 additions and 20 deletions.
406 changes: 400 additions & 6 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion Cargo.toml
@@ -1,5 +1,12 @@
 [workspace]
-members = ["gguf", "tensor", "causal-lm", "models/llama/common", "test-utils"]
+members = [
+    "gguf",
+    "tensor",
+    "causal-lm",
+    "models/llama/common",
+    "models/llama/common-cpu",
+    "test-utils",
+]
 resolver = "2"

 [workspace.dependencies]
18 changes: 18 additions & 0 deletions models/llama/common-cpu/Cargo.toml
@@ -0,0 +1,18 @@
[package]
name = "llama-cpu"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llama.path = "../common"
operators = { workspace = true, features = ["common-cpu"] }
memmap2.workspace = true
tensor.workspace = true
half = "2.4"

[dev-dependencies]
test-utils.workspace = true
gguf.workspace = true
203 changes: 203 additions & 0 deletions models/llama/common-cpu/src/lib.rs
@@ -0,0 +1,203 @@
use half::f16;
use llama::{
BlkWeight, LlamaArgs, LlamaBlkStorage, LlamaBlks, LlamaMeta, LlamaRequest, LlamaStorage,
RandomSample, WeightLoader,
};
use memmap2::Mmap;
use operators::{
common_cpu::{Cpu, ThisThread},
random_sample::{common_cpu::Operator as CpuOp, KVPair, SampleArgs},
QueueOf,
};
use std::slice::from_raw_parts_mut;
use tensor::Tensor;

pub struct Llama {
_storage: Box<[Mmap]>,
token_embed: &'static [u8],
single: LlamaBlks<Cpu, Weights, Operators>,
sample: RandomSample<Cpu, CpuOp>,
}

impl Llama {
pub fn new(_storage: Box<[Mmap]>, model: LlamaStorage<&'static [u8]>) -> Self {
let LlamaStorage {
meta,
token_embed,
output_norm,
output,
blocks,
} = model;
assert_eq!(meta.distribute, 1);
assert!(meta.dt_mat.nbytes().is_some());
Self {
_storage,
token_embed,
single: LlamaBlks::new(
&Cpu,
meta,
Weights {
blks: blocks,
output_norm,
output,
},
),
sample: RandomSample::new(&Cpu),
}
}

pub fn infer(&self, input: &[u32], cache: &mut [u8], pos: usize) -> u32 {
let meta = self.single.meta();
let &LlamaMeta {
dt_mat: element,
dctx,
dh,
..
} = meta;
let cache = meta.kv_cache(dctx, cache);
let embd = meta.embd(input.len(), ());
let logits = meta.logits(1, ());

let ele = element.nbytes().unwrap();
let mut embd_buf = vec![0u8; embd.shape().iter().product::<usize>() * ele];
let mut logits_buf = vec![0u8; logits.shape().iter().product::<usize>() * ele];

// note: shape()[1] counts elements, so scale by the element size to get the
// row stride in bytes before slicing the raw byte buffers.
let d = embd.shape()[1] * ele;
for (i, &tok) in input.iter().enumerate() {
embd_buf[i * d..][..d].copy_from_slice(&self.token_embed[tok as usize * d..][..d]);
}

let mut logits = logits.map(|()| &mut *logits_buf);

self.single
.launch(
LlamaArgs {
embd: embd.map(|()| &mut *embd_buf),
logits: logits.map_slice_mut(),
// Placeholder rope tables: shape [0, dh] with no backing data in this
// initial CPU implementation.
sin: Tensor::new(element, &[0, dh], &[]),
cos: Tensor::new(element, &[0, dh], &[]),
requests: vec![LlamaRequest {
cache,
seq_len: input.len(),
out_len: 1,
pos,
}],
num_tokens: input.len(),
max_seq_len: input.len(),
max_att_len: pos + input.len(),
mlp_alpha: 1.,
},
&mut [],
&ThisThread,
)
.unwrap();

// Argmax sampling yields a single (token index, logit) pair; expose it as
// raw bytes so a rank-0 tensor can borrow it in place.
let mut pair = KVPair::new(0, f16::ZERO);
let mut pairs = Tensor::new(KVPair::<()>::LAYOUT, &[], unsafe {
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
});

self.sample
.launch(
&mut pairs,
&logits,
&Tensor::new(element, &[], &[0u8; 0][..]),
SampleArgs::ARG_MAX,
&mut [],
&ThisThread,
)
.unwrap();

pair.idx() as u32
}
}

struct Operators;

macro_rules! op {
($name:ident) => {
operators::$name::common_cpu::Operator
};
}

impl llama::Operators for Operators {
type Hardware = Cpu;
type RmsNorm = op!(rms_norm);
type MatMul = op!(mat_mul);
type Rope = op!(rope);
type AttnKVCached = op!(attention_kv_cached);
type Mlp = op!(mlp);
type Rearrange = op!(rearrange);
}

struct Weights {
blks: Box<[LlamaBlkStorage<&'static [u8]>]>,
output_norm: &'static [u8],
output: &'static [u8],
}

impl WeightLoader for Weights {
type Hardware = Cpu;
type Memory = &'static [u8];

#[inline]
fn load_blk(
&self,
which: BlkWeight,
iblk: usize,
_queue: &QueueOf<Self::Hardware>,
) -> Self::Memory {
let blk = &self.blks[iblk];
match which {
BlkWeight::AttnNorm => blk.attn_norm,
BlkWeight::AttnQKV => blk.attn_qkv,
BlkWeight::AttnO => blk.attn_o,
BlkWeight::FfnNorm => blk.ffn_norm,
BlkWeight::FfnGateUp => blk.ffn_gate_up,
BlkWeight::FfnDown => blk.ffn_down,
}
}

#[inline]
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory {
self.output_norm
}

#[inline]
fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Memory {
self.output
}
}

#[test]
fn test_load() {
use gguf::GGufModel;
use std::{io::Write, slice::from_raw_parts};

let Some(shards) = test_utils::map_gguf_files() else {
return;
};
let gguf = GGufModel::read(shards.iter().map(|s| &**s));
let tokenizer = gguf.tokenizer();
let llama =
LlamaStorage::from_gguf(&gguf).map(&mut |s| unsafe { from_raw_parts(s.as_ptr(), s.len()) });
let llama = Llama::new(shards, llama);

let meta = llama.single.meta();
println!("{meta:?}");

let cache = meta.kv_cache(meta.dctx, ());
let mut cache_buf = vec![0u8; cache.shape().iter().product::<usize>() * size_of::<f16>()];

let mut prompt = "Once upon a time,".to_string();
let mut tokens = tokenizer.encode(&prompt);
// Track the write position in the KV cache across steps; always passing
// pos = 0 would overwrite the cache from the start on every call.
let mut pos = 0;
// Token 2 is assumed to be the end-of-sequence id here.
while !tokens.contains(&2) {
let next = llama.infer(&tokens, &mut cache_buf, pos);
pos += tokens.len();
tokens = vec![next];

let piece = tokenizer.decode(next);
print!("{piece}");
std::io::stdout().flush().unwrap();
prompt.push_str(&piece);
}
}
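
For a sense of scale, here is a back-of-the-envelope sizing of the cache buffer the test above allocates. The dimensions are hypothetical (roughly TinyLlama-1.1B-shaped) and are not taken from this commit:

// Hypothetical model dimensions, for illustration only:
// nblk = 22 layers, nkvh = 4 KV heads, dh = 64, dctx = 2048, f16 elements.
// kv_cache shape per sequence is [dctx, nblk, 2, nkvh, dh]:
fn main() {
    let (dctx, nblk, nkvh, dh, ele) = (2048usize, 22, 4, 64, 2);
    let bytes = dctx * nblk * 2 * nkvh * dh * ele;
    assert_eq!(bytes, 46_137_344); // ~44 MiB for a full-context cache
}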
1 change: 1 addition & 0 deletions models/llama/common/Cargo.toml
@@ -12,6 +12,7 @@ gguf.workspace = true
 operators.workspace = true
 tensor.workspace = true
 itertools = "0.13"
+rand = "0.8"

 [dev-dependencies]
 test-utils.workspace = true
7 changes: 3 additions & 4 deletions models/llama/common/src/args.rs
@@ -4,12 +4,13 @@ use tensor::Tensor;
 pub struct Args<'a, H: Hardware> {
     /// shape: [nt, nh x dh]
     pub embd: Tensor<&'a mut [H::Byte]>,
-    /// shape: [n_out, dvoc]
-    pub logits: Tensor<&'a mut [H::Byte]>,
+
     /// shape: [_, dh]
     pub sin: Tensor<&'a [H::Byte]>,
     /// shape: [_, dh]
     pub cos: Tensor<&'a [H::Byte]>,
+    /// shape: [n_out, dvoc]
+    pub logits: Tensor<&'a mut [H::Byte]>,

     pub requests: Vec<Request<'a, H>>,

@@ -18,13 +19,11 @@ pub struct Args<'a, H: Hardware> {
     pub max_att_len: usize,

     pub mlp_alpha: f32,
-    pub residual: bool,
 }

 pub struct Request<'a, H: Hardware> {
     /// shape: [buf, nblk, 2, nkvh, dh]
     pub cache: Tensor<&'a mut [H::Byte]>,
-    pub buf_len: usize,
     pub seq_len: usize,
     pub out_len: usize,
     pub pos: usize,
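
With buf_len gone (the capacity is recoverable from the cache tensor's [buf, nblk, 2, nkvh, dh] shape), a request is fully described by its cache plus three counters. Below is a minimal mock of that bookkeeping, with semantics inferred from how compute.rs below and llama-cpu above use the fields; the real Request borrows tensors, which this sketch omits:

// Mock of the per-request bookkeeping (inferred semantics, not the real type):
struct Req {
    seq_len: usize, // new tokens fed in this step
    out_len: usize, // trailing positions that should produce logits
    pos: usize,     // where the new tokens start in this sequence's KV cache
}

fn main() {
    let r = Req { seq_len: 5, out_len: 1, pos: 0 };
    let max_att_len = r.pos + r.seq_len; // attention spans cached + new tokens
    assert_eq!(max_att_len, 5);
}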
17 changes: 12 additions & 5 deletions models/llama/common/src/compute.rs
@@ -79,6 +79,11 @@
             rearrange: Ops::Rearrange::new(processor),
         }
     }
+
+    #[inline]
+    pub const fn meta(&self) -> &LlamaMeta {
+        &self.meta
+    }
 }

 impl<H, W, Ops> LlamaBlks<H, W, Ops>

@@ -107,7 +112,6 @@ where
             max_seq_len,
             max_att_len,
             mlp_alpha,
-            residual,
         } = args;
         let LlamaMeta {
             dt_mat,

@@ -171,7 +175,10 @@
             let mut q = q;
             let mut k = k;
             let v = v;
-            let o = x1.as_mut().map(|t| &mut t[..]);
+            let o = x1
+                .as_mut()
+                .tile(1, &[nh * distribute, dh])
+                .map(|t| &mut t[..]);

             self.rope(&mut q, &pos, &sin, &cos, workspace, queue_alloc)?;
             self.rope(&mut k, &pos, &sin, &cos, workspace, queue_alloc)?;

@@ -219,7 +226,7 @@
             self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;

             #[rustfmt::skip]
-            self.mlp(&mut x, &x1, iblk, mlp_alpha, residual, workspace, queue_alloc)?;
+            self.mlp(&mut x, &x1, iblk, mlp_alpha, true, workspace, queue_alloc)?;

             if distribute > 1 {
                 todo!("all reduce")

@@ -234,7 +241,7 @@
             src += req.seq_len;
             for src in src - req.out_len..src {
                 if src != dst {
-                    let src = unsafe { x.borrow_raw() }.index(0, src);
+                    let src = unsafe { x.map_slice_static() }.index(0, src);
                     let mut dst = x.map_slice_mut().index(0, dst);
                     self.rearrange(&mut dst, &src, workspace, queue_alloc)?;
                 }

@@ -244,7 +251,7 @@
         assert_eq!(dst, logits.shape()[0]);

         let w = self.weights.output_norm(queue);
-        let x_ = unsafe { x.borrow_raw() };
+        let x_ = unsafe { x.map_slice_static() };
         let mut x = x.map_slice_mut().slice(0, 0, 1, dst);
         self.rms_norm(&mut x, &x_, &w, workspace, queue_alloc)?;
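
The new tile(1, &[nh * distribute, dh]) call gives the attention output a per-head view of the [nt, nh * dh] activation without copying. A plain-Rust sketch of the index arithmetic behind such a view (the tensor API itself is not reproduced here):

// Viewing a [nt, nh * dh] buffer as [nt, nh, dh] moves no data; it is the
// same flat offset written two ways:
fn flat(nh: usize, dh: usize, t: usize, h: usize, x: usize) -> usize {
    (t * nh + h) * dh + x // equals t * (nh * dh) + h * dh + x
}

fn main() {
    let (nh, dh) = (8, 64);
    assert_eq!(flat(nh, dh, 3, 5, 7), 3 * (nh * dh) + 5 * dh + 7);
}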
25 changes: 22 additions & 3 deletions models/llama/common/src/lib.rs
@@ -1,11 +1,14 @@
 mod args;
 mod compute;
+mod random_sample;
 mod storage;

 use ggus::ggml_quants::digit_layout::DigitLayout;
 use tensor::Tensor;

 pub use args::{Args as LlamaArgs, Request as LlamaRequest};
 pub use compute::{BlkWeight, LlamaBlks, Operators, WeightLoader};
+pub use random_sample::RandomSample;
 pub use storage::{BlkStorage as LlamaBlkStorage, Storage as LlamaStorage};

 #[derive(Clone, Debug)]

@@ -37,9 +40,25 @@ impl LlamaMeta {
         Tensor::new(dt_mat, &[buf, nblk, 2, nkvh / distribute, dh], p)
     }

+    pub fn embd<T>(&self, nt: usize, p: T) -> Tensor<T> {
+        let &Self { dt_mat, nh, dh, .. } = self;
+        Tensor::new(dt_mat, &[nt, nh * dh], p)
+    }
+
+    pub fn logits<T>(&self, nt: usize, p: T) -> Tensor<T> {
+        let &Self { dt_mat, dvoc, .. } = self;
+        Tensor::new(dt_mat, &[nt, dvoc], p)
+    }
+
     pub fn token_embd<T>(&self, p: T) -> Tensor<T> {
-        let &Self { nh, dh, dvoc, .. } = self;
-        self.mat(p, dvoc, nh * dh, false)
+        let &Self {
+            dt_mat,
+            nh,
+            dh,
+            dvoc,
+            ..
+        } = self;
+        Tensor::new(dt_mat, &[dvoc, nh * dh], p)
     }

     pub fn attn_norm<T>(&self, p: T) -> Tensor<T> {

@@ -103,7 +122,7 @@
     }

     pub fn output<T>(&self, p: T) -> Tensor<T> {
-        self.token_embd(p)
+        self.token_embd(p).transpose(&[1, 0])
     }

     fn norm<T>(&self, p: T) -> Tensor<T> {
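
The output change makes the weight tying explicit: the LM head reuses the token-embedding matrix, so output() now returns the [dvoc, nh * dh] table transposed into the [nh * dh, dvoc] orientation the logits matmul expects. A small sketch of that shape relationship, with made-up sizes:

// Hypothetical sizes; only the shape relationship matters here.
fn main() {
    let (dvoc, d) = (32000usize, 2048); // d = nh * dh
    let token_embd = [dvoc, d]; // lookup table: one row per token
    let output = [d, dvoc];     // same data, transposed for logits = x * W
    assert_eq!([token_embd[1], token_embd[0]], output);
}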