refactor(gpt2): replace the residual connection with the basic add operator
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Dec 30, 2024
1 parent 152266b commit 2f059ce
Showing 5 changed files with 73 additions and 41 deletions.
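The change in a nutshell: both residual connections in the GPT-2 block previously went through the row-indexed `add_rows` operator and needed an extra `idx_add` index tensor; they now use a plain element-wise `add`. The sketch below illustrates why the two are equivalent for this use. It is plain Rust over flat `f32` slices, not the operators-rs API, and the identity index for the old path is an assumption inferred from the removed `postion(input.len(), 0)` argument in the test.

```rust
// Conceptual sketch only: contrasts the old add_rows-based residual with the new
// element-wise add. Buffers are flat row-major [n, d] slices for simplicity.

/// Old path: dst[i] += src[idx[i]] row by row. With an identity index
/// (idx[i] == i, as idx_add appears to have been), this is just dst += src.
fn add_rows(dst: &mut [f32], src: &[f32], idx: &[usize], d: usize) {
    for (i, &row) in idx.iter().enumerate() {
        for j in 0..d {
            dst[i * d + j] += src[row * d + j];
        }
    }
}

/// New path: c = a + b element-wise; no index tensor required.
fn add(c: &mut [f32], a: &[f32], b: &[f32]) {
    for ((c, &a), &b) in c.iter_mut().zip(a).zip(b) {
        *c = a + b;
    }
}

fn main() {
    let d = 2;
    let x = vec![1.0f32, 2.0, 3.0, 4.0]; // residual stream, 2 rows of width d
    let x1 = vec![10.0f32, 20.0, 30.0, 40.0]; // sub-block output

    // Before: residual via add_rows with an identity index.
    let mut before = x.clone();
    add_rows(&mut before, &x1, &[0, 1], d);

    // After: residual via a plain add.
    let mut after = x.clone();
    add(&mut after, &x, &x1);

    assert_eq!(before, after);
}
```

Because the index carried no information, dropping it also lets the `idx_add` field disappear from `Args` and from the CPU test below.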
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -34,7 +34,7 @@ itertools = "0.13"
env_logger = "0.11"
build-script-cfg = "0.0"

-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "78b578d", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "2e8ec0e", default-features = false }

search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }
1 change: 0 additions & 1 deletion models/gpt2/common-cpu/src/infer.rs
@@ -69,7 +69,6 @@ fn test_infer() {
embd: embd.map_slice_mut(),
logits: logits.map_slice_mut(),
idx: postion(input.len(), pos).map_slice(),
-idx_add: postion(input.len(), 0).map_slice(),
requests: vec![gpt2::args::Request {
cache: cache.map_slice_mut(),
seq_len: input.len(),
1 change: 1 addition & 0 deletions models/gpt2/common-cpu/src/lib.rs
@@ -42,6 +42,7 @@ where
type MatMul = op!(mat_mul);
type AttnKVCached = op!(attention_kv_cached);
type Gelu = op!(gelu);
+type Add = op!(add);
type Rearrange = op!(rearrange);
type AllReduce = R;

1 change: 0 additions & 1 deletion models/gpt2/common/src/args.rs
@@ -7,7 +7,6 @@ pub struct Args<'a, H: Hardware> {
/// shape: [nout, nvoc]
pub logits: Tensor<&'a mut [H::Byte]>,
pub idx: Tensor<&'a [H::Byte]>,
-pub idx_add: Tensor<&'a [H::Byte]>,
pub requests: Vec<Request<'a, H>>,

pub max_seq_len: usize,
109 changes: 71 additions & 38 deletions models/gpt2/common/src/compute.rs
@@ -1,6 +1,7 @@
use super::{args::Args, Gpt2Meta};
use itertools::izip;
use operators::{
+add::{self, Add},
add_rows::{self, AddRows},
all_reduce::{self, AllReduce, ReduceOp},
attention_kv_cached::{self, AttnKVCached},
@@ -21,6 +22,7 @@ pub trait Operators {
type MatMul: MatMul<Self::Hardware>;
type AttnKVCached: AttnKVCached<Self::Hardware>;
type Gelu: Gelu<Self::Hardware>;
+type Add: Add<Self::Hardware>;
type Rearrange: Rearrange<Self::Hardware>;
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;

@@ -65,6 +67,7 @@ pub struct Gpt2Worker<Ops: Operators, W> {
mat_mul: Ops::MatMul,
attn_kv_cached: Ops::AttnKVCached,
gelu: Ops::Gelu,
+add: Ops::Add,
rearrange: Ops::Rearrange,
all_reduce: Ops::AllReduce,
pub debug: bool,
@@ -81,6 +84,7 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
mat_mul: Ops::MatMul::new(processor),
attn_kv_cached: Ops::AttnKVCached::new(processor),
gelu: Ops::Gelu::new(processor),
+add: Ops::Add::new(processor),
rearrange: Ops::Rearrange::new(processor),
all_reduce: Ops::AllReduce::new(node),
debug: true,
@@ -132,7 +136,6 @@ where
max_seq_len,
max_att_len,
idx,
-idx_add,
} = args;
let Gpt2Meta {
nblk,
@@ -213,25 +216,27 @@
let [w, b] = self.weights.attn_o(iblk, queue);
self.mat_mul(&mut x1, &o, (w, Some(b)), workspace, queue_alloc)?
}
-self.add_rows(&mut x1, &x, &idx_add, workspace, queue_alloc)?;
-self.all_reduce(&mut x1, workspace, queue_alloc)?;
+let inplace = unsafe { x.map_slice_static() };
+self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
+self.all_reduce(&mut x, workspace, queue_alloc)?;

let wb = self.weights.ffn_norm(iblk, queue);
-self.layer_norm(&mut x, &x1, wb, workspace, queue_alloc)?;
+self.layer_norm(&mut x1, &x, wb, workspace, queue_alloc)?;
{
let (buf, workspace) = workspace.split_at_mut(*up.get());
let mut up = up.clone().map(|_| buf);

let [w, b] = self.weights.ffn_up(iblk, queue);
-self.mat_mul(&mut up, &x, (w, Some(b)), workspace, queue_alloc)?;
+self.mat_mul(&mut up, &x1, (w, Some(b)), workspace, queue_alloc)?;

self.gelu(&mut up, workspace, queue_alloc)?;

let [w, b] = self.weights.ffn_down(iblk, queue);
-self.mat_mul(&mut x, &up, (w, Some(b)), workspace, queue_alloc)?
+self.mat_mul(&mut x1, &up, (w, Some(b)), workspace, queue_alloc)?
}
-self.add_rows(&mut x, &x1, &idx_add, workspace, queue_alloc)?;
-self.all_reduce(&mut x1, workspace, queue_alloc)?
+let inplace = unsafe { x.map_slice_static() };
+self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
+self.all_reduce(&mut x, workspace, queue_alloc)?
}
if logits.shape()[0] == 0 {
return Ok(());
@@ -271,6 +276,36 @@
Ops: Operators,
W: WeightLoader<Hardware = Ops::Hardware>,
{
+fn add_rows<Dst, Src, Idx, QA>(
+&self,
+dst: &mut Tensor<Dst>,
+src: &Tensor<Src>,
+idx: &Tensor<Idx>,
+workspace: &mut [ByteOf<Ops::Hardware>],
+queue_alloc: &QA,
+) -> Result<(), LaunchError>
+where
+Dst: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+Src: Deref<Target = [ByteOf<Ops::Hardware>]>,
+Idx: Deref<Target = [ByteOf<Ops::Hardware>]>,
+QA: QueueAlloc<Hardware = Ops::Hardware>,
+{
+let n = dst.shape()[0];
+let mut dst = dst.map_slice_mut().tile(0, &[1, n]);
+self.add_rows.launch(
+&add_rows::Args {
+dst_layout: dst.layout(),
+dst_base: dst.base_mut(),
+src_layout: src.layout(),
+src_base: src.base(),
+idx_layout: idx.layout(),
+idx_base: idx.base(),
+},
+workspace,
+queue_alloc,
+)
+}
+
fn layer_norm<Y, X, WB, QA>(
&self,
y: &mut Tensor<Y>,
@@ -402,6 +437,34 @@
)
}

+fn add<C, A, B, QA>(
+&self,
+c: &mut Tensor<C>,
+a: &Tensor<A>,
+b: &Tensor<B>,
+workspace: &mut [ByteOf<Ops::Hardware>],
+queue_alloc: &QA,
+) -> Result<(), LaunchError>
+where
+C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+B: Deref<Target = [ByteOf<Ops::Hardware>]>,
+QA: QueueAlloc<Hardware = Ops::Hardware>,
+{
+self.add.launch(
+&add::Args {
+c_layout: c.layout(),
+c_base: c.base_mut(),
+a_layout: a.layout(),
+a_base: a.base(),
+b_layout: b.layout(),
+b_base: b.base(),
+},
+workspace,
+queue_alloc,
+)
+}
+
fn rearrange<Y, X, QA>(
&self,
dst: &mut Tensor<Y>,
@@ -450,36 +513,6 @@
queue_alloc,
)
}

-fn add_rows<Dst, Src, Idx, QA>(
-&self,
-dst: &mut Tensor<Dst>,
-src: &Tensor<Src>,
-idx: &Tensor<Idx>,
-workspace: &mut [ByteOf<Ops::Hardware>],
-queue_alloc: &QA,
-) -> Result<(), LaunchError>
-where
-Dst: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
-Src: Deref<Target = [ByteOf<Ops::Hardware>]>,
-Idx: Deref<Target = [ByteOf<Ops::Hardware>]>,
-QA: QueueAlloc<Hardware = Ops::Hardware>,
-{
-let n = dst.shape()[0];
-let mut dst = dst.map_slice_mut().tile(0, &[1, n]);
-self.add_rows.launch(
-&add_rows::Args {
-dst_layout: dst.layout(),
-dst_base: dst.base_mut(),
-src_layout: src.layout(),
-src_base: src.base(),
-idx_layout: idx.layout(),
-idx_base: idx.base(),
-},
-workspace,
-queue_alloc,
-)
-}
}

struct WeightDecorator<W> {

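A side note on the new call sites in compute.rs: the residual add runs in place, with `x` acting both as the output `c` and as the first input `a`, which is why the worker first builds an aliasing view via `unsafe { x.map_slice_static() }`. The snippet below is a plain-Rust illustration of the aliasing constraint being worked around; the `add` function here is a stand-in, not the operators-rs launcher.

```rust
/// Stand-in ternary add, c = a + b (not the operators-rs API).
fn add(c: &mut [f32], a: &[f32], b: &[f32]) {
    for ((c, &a), &b) in c.iter_mut().zip(a).zip(b) {
        *c = a + b;
    }
}

fn main() {
    let mut x = vec![1.0f32, 2.0, 3.0]; // residual stream
    let x1 = vec![10.0f32, 20.0, 30.0]; // sub-block output

    // add(&mut x, &x, &x1); // rejected: `x` would be borrowed mutably and immutably at once.

    // Safe plain-Rust workaround: snapshot the input first. The worker avoids this copy by
    // creating an aliasing tensor view of `x` instead, hence the `unsafe` at that call site.
    let snapshot = x.clone();
    add(&mut x, &snapshot, &x1);
    assert_eq!(x, [11.0, 22.0, 33.0]); // x = x + x1, the residual connection
}
```

Either way the result is the same `x = x + x1` residual, just without materializing an index tensor or an extra output buffer.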