-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3d8b052
commit db22ce2
Showing
10 changed files
with
1,806 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[package] | ||
name = "minicpm3-cpu" | ||
version = "0.0.0" | ||
edition = "2021" | ||
authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"] | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
minicpm3.path = "../common" | ||
common.workspace = true | ||
operators = { workspace = true, features = ["common-cpu"] } | ||
|
||
[dev-dependencies] | ||
test-utils.workspace = true | ||
gguf.workspace = true | ||
regex.workspace = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
use crate::{Operators, RandomSample, Weights}; | ||
use common::Distribution; | ||
use gguf::GGufModel; | ||
use minicpm3::{ext::ggml_quants::f16, MiniCPM3Request, MiniCPM3Storage, Minicpm3Worker, Tensor}; | ||
use operators::{ | ||
all_reduce::common_cpu::Operator as AllReduce, | ||
common_cpu::{InprocNode, ThisThread}, | ||
random_sample::{KVPair, SampleArgs}, | ||
Blob, | ||
}; | ||
use regex::Regex; | ||
use std::{ | ||
iter::zip, | ||
ptr::copy_nonoverlapping, | ||
slice::from_raw_parts_mut, | ||
sync::{Arc, Barrier}, | ||
thread, | ||
}; | ||
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed}; | ||
|
||
/// Worker under test: `Minicpm3Worker` driven by the CPU `Operators`, using an | ||
/// in-process topology node (`InprocNode<usize>`) and the CPU all-reduce operator. | ||
type Worker<'w> = Minicpm3Worker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>; | ||
|
||
/// End-to-end inference test for the CPU MiniCPM3 backend. | ||
/// | ||
/// Reads a GGUF model, spawns one scoped worker thread per entry of the device | ||
/// distribution, and hands the matching `senders` to `test_infer_paralle`, which | ||
/// drives generation. Returns early when `Inference::load` yields `None`. | ||
#[test] | ||
fn test_infer() { | ||
    // Fall back to the author's local model file only when no `TEST_MODEL` has been | ||
    // provided, so an externally chosen model is not overwritten. | ||
    // (presumably `Inference::load` below reads this variable — confirm in test-utils) | ||
    if std::env::var_os("TEST_MODEL").is_none() { | ||
        std::env::set_var( | ||
            "TEST_MODEL", | ||
            "/home/ztf/cpm/Origin-MiniCPM3-4B-v0.0-F16.gguf", | ||
        ); | ||
    } | ||
    let Some(Inference { | ||
        model, | ||
        devices, | ||
        mut prompt, | ||
        as_user, | ||
        temperature, | ||
        top_p, | ||
        top_k, | ||
        max_steps, | ||
    }) = Inference::load() | ||
    else { | ||
        return; | ||
    }; | ||
    // NOTE(review): the loaded prompt is discarded in favour of this fixed one; | ||
    // this looks like a debugging override — confirm it is intentional. | ||
    prompt = "我".to_owned(); | ||
    let gguf = GGufModel::read(model.iter().map(|s| &**s)); | ||
|
||
    let TokenizerAndPrompt { | ||
        eos, | ||
        tokenizer, | ||
        prompt, | ||
    } = TokenizerAndPrompt::new(&gguf, prompt, as_user); | ||
|
||
    let model = MiniCPM3Storage::from_gguf(&gguf); | ||
    println!("{:?}", model.meta); | ||
|
||
    let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args"); | ||
    println!("{sample_args:?}"); | ||
|
||
    // Every run of digits in `devices` is one worker's share; without a device | ||
    // string a single worker with share 1 is used. | ||
    let lens = devices | ||
        .map(|devices| { | ||
            Regex::new(r"\d+") | ||
                .unwrap() | ||
                .find_iter(&devices) | ||
                .map(|c| c.as_str().parse().unwrap()) | ||
                .collect() | ||
        }) | ||
        .unwrap_or_else(|| vec![1]); | ||
    // Total of all shares; passed to `Distribution::new` as the whole. | ||
    let dist = lens.iter().sum(); | ||
    println!("distribution: {lens:?}"); | ||
|
||
    let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len())); | ||
    // One participant per worker thread plus this thread. Sizing the barrier by | ||
    // `dist` (the sum of the shares) would never release when any share exceeds 1, | ||
    // because only `lens.len()` workers ever call `wait()`. | ||
    let barrier = Arc::new(Barrier::new(lens.len() + 1)); | ||
    thread::scope(|s| { | ||
        // `scan` carries the running start offset of each worker's share; the final | ||
        // `collect` drives the lazy iterator so that every worker is actually spawned. | ||
        let _workers = zip(lens, seeds) | ||
            .enumerate() | ||
            .scan(0, |start, (id, (len, seed))| { | ||
                let dist = Distribution::new(*start, len, dist); | ||
                *start += len; | ||
|
||
                let meta = model.meta.distribute(dist); | ||
                let model = &model; | ||
                let barrier = barrier.clone(); | ||
                Some(s.spawn(move || { | ||
                    let WorkerSeed { node, tasks } = seed; | ||
                    let weights = Weights::new(model, dist); | ||
                    let mut worker = Worker::new(id, &node, meta.clone(), weights); | ||
                    let mut cache = meta.kv_cache(meta.nctx).map(Blob::new); | ||
                    let sin_cos = <Operators as minicpm3::Operators>::build_sin_cos( | ||
                        meta.dt_embd, | ||
                        meta.nctx, | ||
                        meta.dh, | ||
                        &ThisThread, | ||
                    ); | ||
|
||
                    let sample = RandomSample::new(&node); | ||
                    let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread); | ||
                    let mut pair = KVPair::new(0, f16::ZERO); | ||
                    // `pairs` is a one-element tensor whose storage is the bytes of `pair`, | ||
                    // so the sampler's output can be read back through `pair` below. | ||
                    // SAFETY: the slice spans exactly `size_of_val(&pair)` bytes of a local | ||
                    // declared in this closure before `pairs`, so it outlives `pairs`. | ||
                    // NOTE(review): `pair` is read directly while `pairs` still aliases it. | ||
                    let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe { | ||
                        from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair)) | ||
                    }); | ||
|
||
                    // Setup is done; wait until every worker and the main thread are ready. | ||
                    barrier.wait(); | ||
                    for task in tasks { | ||
                        let Task { | ||
                            nt, | ||
                            pos, | ||
                            embd, | ||
                            next, | ||
                        } = task; | ||
                        // Copy the incoming embeddings into a buffer owned by this worker. | ||
                        let mut embd = meta.embd(nt).map(|size| { | ||
                            let mut blob = Blob::new(size); | ||
                            // SAFETY: assumes `embd` points to at least `size` readable bytes | ||
                            // that do not overlap the fresh blob — TODO confirm against the | ||
                            // task producer. | ||
                            unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) }; | ||
                            blob | ||
                        }); | ||
                        // Only worker 0 requests an output row (and samples from it below). | ||
                        let mut logits = meta.logits(if id == 0 { 1 } else { 0 }).map(Blob::new); | ||
                        worker | ||
                            .launch( | ||
                                minicpm3::MiniCPM3Args { | ||
                                    embd: embd.map_slice_mut(), | ||
                                    logits: logits.map_slice_mut(), | ||
                                    sin_cos: sin_cos.map_slice(), | ||
                                    requests: vec![MiniCPM3Request { | ||
                                        cache: cache.map_slice_mut(), | ||
                                        seq_len: nt, | ||
                                        out_len: if id == 0 { 1 } else { 0 }, | ||
                                        pos, | ||
                                    }], | ||
                                    num_tokens: nt, | ||
                                    max_seq_len: nt, | ||
                                    max_att_len: nt + pos, | ||
                                }, | ||
                                &mut [], | ||
                                &ThisThread, | ||
                            ) | ||
                            .unwrap(); | ||
                        if id == 0 { | ||
                            sample | ||
                                .launch( | ||
                                    &mut pairs, | ||
                                    &logits, | ||
                                    &indices, | ||
                                    sample_args, | ||
                                    &mut [], | ||
                                    &ThisThread, | ||
                                ) | ||
                                .unwrap(); | ||
                            // Report the sampled token index back through `next`. | ||
                            next.send(pair.idx() as _).unwrap() | ||
                        } | ||
                    } | ||
                })) | ||
            }) | ||
            .collect::<Vec<_>>(); | ||
|
||
        let senders = senders.into_boxed_slice(); | ||
        // Rendezvous with the workers before generation starts. | ||
        barrier.wait(); | ||
        test_infer_paralle( | ||
            senders, | ||
            test_utils::AboutToken { | ||
                tokenizer, | ||
                token_embd: model.token_embd, | ||
                nvoc: model.meta.nvoc, | ||
                eos, | ||
            }, | ||
            &prompt, | ||
            max_steps, | ||
        ) | ||
    }) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
use common::{Contiguous, Distribution}; | ||
use minicpm3::{MiniCPM3BlkStorage, MiniCPM3BlkWeight, MiniCPM3Storage, Tensor, WeightLoader}; | ||
use operators::{ | ||
all_reduce::{AllReduce, NonAllReduce}, | ||
common_cpu::Cpu, | ||
random_sample::common_cpu::Operator as RandomSampleCpu, | ||
rearrange::common_cpu::Operator as Rearrange, | ||
Blob, ByteOf, QueueOf, TopoNode, | ||
}; | ||
use std::{marker::PhantomData, ops::Deref}; | ||
|
||
/// Marker type that selects the CPU operator set for MiniCPM3. | ||
/// | ||
/// It holds only `PhantomData`; `N` is the topology node type and `R` the | ||
/// all-reduce implementation, defaulting to `Cpu` and `NonAllReduce<Cpu, Rearrange>`. | ||
pub struct Operators<N = Cpu, R = NonAllReduce<Cpu, Rearrange>>(PhantomData<(N, R)>); | ||
|
||
/// `minicpm3::RandomSample` instantiated with the `Cpu` hardware and the CPU | ||
/// random-sample operator. | ||
pub type RandomSample = minicpm3::RandomSample<Cpu, RandomSampleCpu>; | ||
|
||
/// Weights as seen by one worker. | ||
/// | ||
/// Block weights live in `blks` (built by `Weights::new` through | ||
/// `distribute_data`); the remaining fields are byte slices borrowed for `'w`. | ||
pub struct Weights<'w> { | ||
    // One entry per model block; indexed by `iblk` in `load_blk`. | ||
    blks: Box<[MiniCPM3BlkStorage<Contiguous<'w, Blob>>]>, | ||
    output_norm: &'w [u8], | ||
    output: &'w [u8], | ||
    // Filled from the model's `rope_long`. | ||
    long_factor: &'w [u8], | ||
    // Filled from the model's `rope_short` and returned by `short_factor()`. | ||
    // NOTE(review): the name looks like a typo for `short_factor`. | ||
    sort_factor: &'w [u8], | ||
} | ||
|
||
// Expands `op!(name)` to the type path `operators::name::common_cpu::Operator`. | ||
macro_rules! op { | ||
    ($name:ident) => { | ||
        operators::$name::common_cpu::Operator | ||
    }; | ||
} | ||
|
||
/// Binds every operator slot of `minicpm3::Operators` to its `common_cpu` | ||
/// implementation; only the topology node `N` and the all-reduce `R` stay generic. | ||
impl<N, R> minicpm3::Operators for Operators<N, R> | ||
where | ||
    N: TopoNode<Cpu>, | ||
    R: AllReduce<Cpu, N>, | ||
{ | ||
    type Hardware = Cpu; | ||
    type TopoNode = N; | ||
    type Rope = op!(rope); | ||
    type Attention = op!(attention); | ||
    type RmsNorm = op!(rms_norm); | ||
    type Add = op!(add); | ||
    type MatMul = op!(mat_mul); | ||
    type Swiglu = op!(swiglu); | ||
    type Rearrange = op!(rearrange); | ||
    type Scale = op!(scale); | ||
    type AttnKVCached = op!(attention_kv_cached); | ||
    type AllReduce = R; | ||
|
||
    /// Prints `tensor` to stdout via its `Display` implementation; the queue is unused. | ||
    fn debug<T>(tensor: &Tensor<T>, _queue: &QueueOf<Self::Hardware>) | ||
    where | ||
        T: Deref<Target = [ByteOf<Self::Hardware>]>, | ||
    { | ||
        println!("{tensor}") | ||
    } | ||
} | ||
|
||
impl<'w> Weights<'w> { | ||
    /// Builds the weights one worker needs from the full model. | ||
    /// | ||
    /// Each block is cloned and every `(which, data)` pair in it is passed through | ||
    /// `meta.distribute_data(which, data, dist, Blob::new)`. `output_norm`, `output`, | ||
    /// `rope_long` and `rope_short` are taken from `model` as they are. | ||
    pub fn new(model: &'w MiniCPM3Storage<&'w [u8]>, dist: Distribution) -> Self { | ||
        let MiniCPM3Storage { | ||
            meta, | ||
            output_norm, | ||
            output, | ||
            blocks, | ||
            rope_long, | ||
            rope_short, | ||
            .. | ||
        } = model; | ||
|
||
        // Rebuild each block's storage with its tensors distributed for `dist`. | ||
        let blks = blocks | ||
            .iter() | ||
            .map(|blk| { | ||
                blk.clone() | ||
                    .into_vec() | ||
                    .into_iter() | ||
                    .map(|(which, data)| { | ||
                        (which, meta.distribute_data(which, data, dist, Blob::new)) | ||
                    }) | ||
                    .collect::<MiniCPM3BlkStorage<_>>() | ||
            }) | ||
            .collect(); | ||
|
||
        Self { | ||
            blks, | ||
            output_norm, | ||
            output, | ||
            long_factor: rope_long, | ||
            sort_factor: rope_short, | ||
        } | ||
    } | ||
} | ||
|
||
/// Serves weights as plain byte slices borrowed from `self`; every method ignores | ||
/// its queue argument. | ||
impl WeightLoader for Weights<'_> { | ||
    type Hardware = Cpu; | ||
    type Weight<'s> | ||
        = &'s [u8] | ||
    where | ||
        Self: 's; | ||
|
||
    /// Returns the tensor `which` of block `iblk`. | ||
    /// | ||
    /// Panics if `iblk` is out of range, since `blks` is indexed directly. | ||
    #[inline] | ||
    fn load_blk( | ||
        &self, | ||
        which: MiniCPM3BlkWeight, | ||
        iblk: usize, | ||
        _queue: &QueueOf<Self::Hardware>, | ||
    ) -> Self::Weight<'_> { | ||
        // The pattern names every field without `..`, so a field added to the | ||
        // storage struct is a compile error here until it is handled. | ||
        let MiniCPM3BlkStorage { | ||
            attn_norm, | ||
            attn_qb, | ||
            attn_qa, | ||
            attn_kvb, | ||
            attn_kva, | ||
            attn_qa_norm, | ||
            attn_kva_norm, | ||
            attn_o, | ||
            ffn_norm, | ||
            ffn_gate_up, | ||
            ffn_down, | ||
            ffn_gate, | ||
            ffn_up, | ||
        } = &self.blks[iblk]; | ||
        use MiniCPM3BlkWeight as W; | ||
        match which { | ||
            W::AttnNorm => attn_norm, | ||
            W::AttnQB => attn_qb, | ||
            W::AttnQA => attn_qa, | ||
            W::AttnKvB => attn_kvb, | ||
            W::AttnKvA => attn_kva, | ||
            W::AttnQANorm => attn_qa_norm, | ||
            W::AttnKvANorm => attn_kva_norm, | ||
            W::AttnO => attn_o, | ||
            W::FfnNorm => ffn_norm, | ||
            W::FfnGateUp => ffn_gate_up, | ||
            W::FfnDown => ffn_down, | ||
            W::FfnGate => ffn_gate, | ||
            W::FfnUp => ffn_up, | ||
        } | ||
    } | ||
|
||
    /// Returns the stored `output_norm` slice. | ||
    #[inline] | ||
    fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> { | ||
        self.output_norm | ||
    } | ||
|
||
    /// Returns the stored `output` slice. | ||
    #[inline] | ||
    fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> { | ||
        self.output | ||
    } | ||
    /// Returns the slice stored in `long_factor` (taken from `rope_long`). | ||
    #[inline] | ||
    fn long_factor<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'_> { | ||
        self.long_factor | ||
    } | ||
    /// Returns the slice stored in `sort_factor` (taken from `rope_short`). | ||
    #[inline] | ||
    fn short_factor<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'_> { | ||
        self.sort_factor | ||
    } | ||
} | ||
|
||
// Test-only module; compiled only under `cfg(test)`. | ||
#[cfg(test)] | ||
mod infer; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "minicpm3" | ||
version = "0.0.0" | ||
edition = "2021" | ||
authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"] | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
common.workspace = true | ||
gguf.workspace = true | ||
tensor.workspace = true | ||
operators.workspace = true | ||
itertools.workspace = true | ||
half = "2.4" | ||
|
||
[dev-dependencies] | ||
test-utils.workspace = true |
Oops, something went wrong.