Skip to content

Commit

Permalink
添加 scale 算子 (Add the `scale` operator)
Browse files Browse the repository at this point in the history
  • Loading branch information
onenewcode committed Feb 14, 2025
1 parent 166d2ea commit 9785400
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ itertools = "0.13"
env_logger = "0.11"
build-script-cfg = "0.0"

operators = { git = "https://github.com/onenewcode/operators-rs", branch = "dev", default-features = false }
operators = { git = "https://github.com/onenewcode/operators-rs", rev = "f4a83f7", default-features = false }

search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }
Expand Down
1 change: 1 addition & 0 deletions models/minicpm3/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ where
type MatMul = op!(mat_mul);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type Scale = op!(scale);
type AttnKVCached = op!(attention_kv_cached);
type AllReduce = R;

Expand Down
47 changes: 35 additions & 12 deletions models/minicpm3/common/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use gguf::ggml_quants::digit_layout::types as ty;
use gguf::ggml_quants::digit_layout::DigitLayout;
use half::f16;
use itertools::{izip, Itertools};
use operators::scale;
use operators::scale::Scale;
use operators::{
add::{self, Add},
all_reduce::{self, AllReduce, ReduceOp},
Expand Down Expand Up @@ -30,6 +32,7 @@ pub trait Operators {
type Add: Add<Self::Hardware>;
type MatMul: MatMul<Self::Hardware>;
type Swiglu: Swiglu<Self::Hardware>;
type Scale:Scale<Self::Hardware>;
type Rearrange: Rearrange<Self::Hardware>;
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;

Expand Down Expand Up @@ -82,6 +85,7 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
rope: Ops::Rope,
rms_norm: Ops::RmsNorm,
mat_mul: Ops::MatMul,
scale:Ops::Scale,
swiglu: Ops::Swiglu,
rearrange: Ops::Rearrange,
all_reduce: Ops::AllReduce,
Expand All @@ -98,6 +102,7 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
rope: Ops::Rope::new(processor),
rms_norm: Ops::RmsNorm::new(processor),
mat_mul: Ops::MatMul::new(processor),
scale: Ops::Scale::new(processor),
swiglu: Ops::Swiglu::new(processor),
rearrange: Ops::Rearrange::new(processor),
add: Ops::Add::new(processor),
Expand Down Expand Up @@ -154,18 +159,7 @@ where
let scale_depth = 1.4f32;
// 残差连接时权重缩放
let s = scale_depth / (nblk as f32).sqrt();
fn ggml_scale(embd: *mut f16, s: f16, l: usize) {
if l == 0 {
return;
} // 如果长度为 0,则无需进行任何操作

unsafe {
let slice = std::slice::from_raw_parts_mut(embd, l);
slice.iter_mut().for_each(|x| *x *= s);
}
}

ggml_scale(x.base_mut().cast::<f16>(), f16::from_f32(scale_emb), d);


let dnope = dk - dh;
let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
Expand All @@ -184,6 +178,10 @@ where
let mut attn = attn.map(|_| buf);

let queue = queue_alloc.queue();
// 缩放
let inplace=unsafe {
x.map_slice_static()};
self.scale(&mut x, &inplace, scale_emb, workspace, queue_alloc)?;
for iblk in 0..nblk {
// norm
let w = self.weights.attn_norm(iblk, queue);
Expand Down Expand Up @@ -619,6 +617,31 @@ where
queue_alloc,
)
}
/// Launches the element-wise scale operator: `c = a * s`.
///
/// Thin wrapper around [`operators::scale::Scale::launch`] that adapts the
/// worker's tensors to the operator's argument struct. `c` and `a` may alias
/// (the caller in this file passes the same tensor as both destination and
/// source for an in-place scale).
///
/// # Errors
/// Propagates any [`LaunchError`] reported by the underlying operator.
fn scale<C, A, QA>(
    &self,
    c: &mut Tensor<C>,
    a: &Tensor<A>,
    s: f32,
    workspace: &mut [ByteOf<Ops::Hardware>],
    queue_alloc: &QA,
) -> Result<(), LaunchError>
where
    C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
    A: Deref<Target = [ByteOf<Ops::Hardware>]>,
    QA: QueueAlloc<Hardware = Ops::Hardware>,
{
    self.scale.launch(
        &scale::Args {
            c_layout: c.layout(),
            c_base: c.base_mut(),
            a_layout: a.layout(),
            a_base: a.base(),
            s,
        },
        workspace,
        queue_alloc,
    )
}
fn all_reduce<X, QA>(
&self,
x: &mut Tensor<X>,
Expand Down

0 comments on commit 9785400

Please sign in to comment.