Skip to content

Commit

Permalink
perf(llama): 优化分布式切分和参数加载
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 21, 2025
1 parent 8729b3d commit 816101f
Show file tree
Hide file tree
Showing 9 changed files with 364 additions and 450 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ members = [

"models/llama/common",
"models/llama/common-cpu",
"models/llama/opencl",
"models/llama/infini",
# "models/llama/opencl",
# "models/llama/infini",
"models/llama/cuda",

"models/clip/common",
Expand All @@ -34,7 +34,7 @@ itertools = "0.13"
env_logger = "0.11"
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "df027a4", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "fd8f972", default-features = false }

search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }
Expand Down
46 changes: 45 additions & 1 deletion common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{borrow::Borrow, collections::HashMap, hash::Hash, ops::Deref};
use std::{
borrow::Borrow,
collections::HashMap,
hash::Hash,
ops::{Deref, Range},
};

pub enum Contiguous<'a, T> {
Borrowed(&'a [u8]),
Expand Down Expand Up @@ -52,3 +57,42 @@ impl<K: Eq + Hash, V> Slab<K, V> {
self.0.entry(key).or_default().push(value);
}
}

#[derive(Clone, Copy, Debug)]
pub struct Distribution {
pub start: usize,
pub len: usize,
pub total: usize,
}

impl Distribution {
pub const MONO: Self = Self {
start: 0,
len: 1,
total: 1,
};
}

pub struct WeightMemCalculator {
align: usize,
size: usize,
}

impl WeightMemCalculator {
#[inline]
pub const fn new(align: usize) -> Self {
Self { align, size: 0 }
}

#[inline]
pub const fn size(&self) -> usize {
self.size
}

#[inline]
pub fn push(&mut self, size: usize) -> Range<usize> {
let start = self.size.div_ceil(self.align) * self.align;
self.size = start + size;
start..self.size
}
}
15 changes: 9 additions & 6 deletions models/llama/common-cpu/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{Operators, RandomSample, Weights};
use common::Distribution;
use gguf::GGufModel;
use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
use operators::{
Expand Down Expand Up @@ -59,16 +60,18 @@ fn test_infer() {
let _workers = zip(lens, seeds)
.enumerate()
.scan(0, |start, (id, (len, seed))| {
let range = *start..*start + len;
*start = range.end;

let mut meta = model.meta.clone();
meta.distribute(range.clone(), count);
let dist = Distribution {
start: *start,
len,
total: count,
};
*start += len;

let meta = model.meta.distribute(dist);
let model = &model;
Some(s.spawn(move || {
let WorkerSeed { node, tasks } = seed;
let weights = Weights::new(model, range, count);
let weights = Weights::new(model, dist);
let mut worker = Worker::new(id, &node, meta.clone(), weights);
let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
let sin_cos = <Operators as llama::Operators>::build_sin_cos(
Expand Down
42 changes: 22 additions & 20 deletions models/llama/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use common::Contiguous;
use common::{Contiguous, Distribution};
use llama::{
ext::ggml_quants::{self, digit_layout::DigitLayout, f16, DataBlock, QuantExt},
BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor,
LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor,
TensorUsage::Computation,
WeightLoader,
};
Expand All @@ -16,7 +16,7 @@ use std::{
cell::{Ref, RefCell},
marker::PhantomData,
mem::size_of,
ops::{Deref, Range, RangeBounds},
ops::{Deref, Range},
ptr::copy_nonoverlapping,
slice::{from_raw_parts, from_raw_parts_mut},
};
Expand All @@ -41,7 +41,7 @@ pub struct Weights<'w> {

pub struct WeightCache {
cache: Blob,
cached_weight: BlkWeight,
cached_weight: LlamaBlkWeight,
cached_weight_iblk: usize,
}

Expand Down Expand Up @@ -85,11 +85,7 @@ where
}

impl<'w> Weights<'w> {
pub fn new(
model: &'w LlamaStorage<&'w [u8]>,
range: impl RangeBounds<usize> + Clone,
count: usize,
) -> Self {
pub fn new(model: &'w LlamaStorage<&'w [u8]>, dist: Distribution) -> Self {
let LlamaStorage {
meta,
output_norm,
Expand All @@ -100,11 +96,17 @@ impl<'w> Weights<'w> {

let blks = blocks
.iter()
.map(|blk| blk.distribute(meta, range.clone(), count, Blob::new))
.map(|blk| {
blk.into_vec()
.into_iter()
.map(|(which, data)| {
(which, meta.distribute_data(which, data, dist, Blob::new))
})
.collect::<LlamaBlkStorage<_>>()
})
.collect::<Box<_>>();

let mut meta = meta.clone();
meta.distribute(range.clone(), count);
let meta = meta.distribute(dist);
let size_qkv = meta.attn_qkv(Computation).take();
let size_o = meta.attn_o(Computation).take();
let size_gate_up = meta.ffn_gate_up(Computation).take();
Expand All @@ -113,7 +115,7 @@ impl<'w> Weights<'w> {
let weight_cache = if meta.dt_embd == meta.dt_linear {
RefCell::new(WeightCache {
cache: Blob::new(0),
cached_weight: BlkWeight::AttnQKV,
cached_weight: LlamaBlkWeight::AttnQKV,
cached_weight_iblk: 0,
})
} else {
Expand All @@ -131,7 +133,7 @@ impl<'w> Weights<'w> {

RefCell::new(WeightCache {
cache,
cached_weight: BlkWeight::AttnQKV,
cached_weight: LlamaBlkWeight::AttnQKV,
cached_weight_iblk: 0,
})
};
Expand Down Expand Up @@ -207,7 +209,7 @@ impl WeightLoader for Weights<'_> {
#[inline]
fn load_blk(
&self,
which: BlkWeight,
which: LlamaBlkWeight,
iblk: usize,
_queue: &QueueOf<Self::Hardware>,
) -> Self::Weight<'_> {
Expand All @@ -233,10 +235,10 @@ impl WeightLoader for Weights<'_> {
ffn_down,
} = &blks[iblk];

use BlkWeight::{
use Dequant::{Borrowed, Cached};
use LlamaBlkWeight::{
AttnNorm, AttnO, AttnQKV, AttnQKVBias, FfnDown, FfnGateInp, FfnGateUp, FfnNorm,
};
use Dequant::{Borrowed, Cached};

#[rustfmt::skip]
match which {
Expand Down Expand Up @@ -301,7 +303,7 @@ impl WeightLoader for Weights<'_> {

fn load_moe<'a>(
&'a self,
which: BlkWeight,
which: LlamaBlkWeight,
iblk: usize,
iexp: usize,
_queue: &'a QueueOf<Self::Hardware>,
Expand All @@ -315,8 +317,8 @@ impl WeightLoader for Weights<'_> {
} = self;
assert_eq!(dt_embd, dt_mat);
let w = match which {
BlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up,
BlkWeight::FfnDown => &*blks[iblk].ffn_down,
LlamaBlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up,
LlamaBlkWeight::FfnDown => &*blks[iblk].ffn_down,
_ => unreachable!(),
};
let one = w.len() / nexp;
Expand Down
38 changes: 15 additions & 23 deletions models/llama/common/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{args::Args, LlamaMeta};
use super::{args::Args, LlamaBlkWeight, LlamaMeta};
use gguf::ggml_quants::{
digit_layout::{types as ty, DigitLayout},
f16,
Expand Down Expand Up @@ -53,18 +53,6 @@ pub trait Operators {
}
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum BlkWeight {
AttnNorm,
AttnQKV,
AttnQKVBias,
AttnO,
FfnNorm,
FfnGateInp,
FfnGateUp,
FfnDown,
}

pub trait WeightLoader {
type Hardware: Hardware;
type Weight<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
Expand All @@ -73,14 +61,14 @@ pub trait WeightLoader {

fn load_blk<'a>(
&'a self,
which: BlkWeight,
which: LlamaBlkWeight,
iblk: usize,
queue: &'a QueueOf<Self::Hardware>,
) -> Self::Weight<'a>;

fn load_moe<'a>(
&'a self,
which: BlkWeight,
which: LlamaBlkWeight,
iblk: usize,
iexp: usize,
queue: &'a QueueOf<Self::Hardware>,
Expand Down Expand Up @@ -638,7 +626,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
let w = self.weights.load_blk(LlamaBlkWeight::AttnNorm, iblk, queue);
self.norm.clone().map(|_| w)
}

Expand All @@ -648,7 +636,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue);
let w = self.weights.load_blk(LlamaBlkWeight::AttnQKV, iblk, queue);
self.attn_qkv.clone().map(|_| w)
}

Expand All @@ -658,7 +646,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::AttnQKVBias, iblk, queue);
let w = self
.weights
.load_blk(LlamaBlkWeight::AttnQKVBias, iblk, queue);
self.attn_qkv_bias.clone().map(|_| w)
}

Expand All @@ -668,7 +658,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::AttnO, iblk, queue);
let w = self.weights.load_blk(LlamaBlkWeight::AttnO, iblk, queue);
self.attn_o.clone().map(|_| w)
}

Expand All @@ -678,7 +668,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
let w = self.weights.load_blk(LlamaBlkWeight::FfnNorm, iblk, queue);
self.norm.clone().map(|_| w)
}

Expand All @@ -688,7 +678,9 @@ impl<W: WeightLoader> WeightDecorator<W> {
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::FfnGateInp, iblk, queue);
let w = self
.weights
.load_blk(LlamaBlkWeight::FfnGateInp, iblk, queue);
self.ffn_gate_inp.clone().map(|_| w)
}

Expand All @@ -699,7 +691,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iexp: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
const WHICH: BlkWeight = BlkWeight::FfnGateUp;
const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnGateUp;
let w = if self.is_moe {
self.weights.load_moe(WHICH, iblk, iexp, queue)
} else {
Expand All @@ -715,7 +707,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
iexp: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
const WHICH: BlkWeight = BlkWeight::FfnDown;
const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnDown;
let w = if self.is_moe {
self.weights.load_moe(WHICH, iblk, iexp, queue)
} else {
Expand Down
Loading

0 comments on commit 816101f

Please sign in to comment.