refactor(tensor): base tensor reform on the operator library's algorithm implementation
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jul 8, 2024
1 parent ca4cab0 commit c2d7128
Showing 8 changed files with 125 additions and 102 deletions.
62 changes: 41 additions & 21 deletions Cargo.lock

(Generated file; diff not rendered.)

3 changes: 1 addition & 2 deletions Cargo.toml
@@ -31,12 +31,11 @@ itertools = "0.13"
 serde = "1.0"
 serde_json = "1.0"
 memmap2 = "0.9"
-rayon = "1.10"
 tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
 digit-layout = "0.0"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "3807deb", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "3e9b113", default-features = false }
 nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }
 search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }
 search-neuware-tools = "0.0"
44 changes: 14 additions & 30 deletions devices/common/src/lib.rs
@@ -1,26 +1,10 @@
 use common::utok;
 use operators::{
-    fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Argument, Handle, Operator, QueueOf,
-    TensorLayout,
+    fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Handle, Operator, QueueOf,
 };
 use std::ops::{Deref, DerefMut};
 use tensor::Tensor;
 
-pub fn layout<T>(t: &Tensor<T>) -> TensorLayout {
-    let dt = t.data_layout();
-    let shape = t
-        .shape()
-        .iter()
-        .map(|&x| Argument::new(x as usize))
-        .collect::<Vec<_>>();
-    let strides = t
-        .strides()
-        .iter()
-        .map(|&x| Argument::new(x as isize * dt.nbytes() as isize))
-        .collect::<Vec<_>>();
-    TensorLayout::new(dt, shape, strides)
-}
-
 pub type SliceOn<H> = [<H as Handle>::Byte];
 
 pub trait Operators {
@@ -118,9 +102,9 @@ impl<Ops: Operators> KernelsA for Ops {
         self.reform_op(queue)
             .launch(
                 &reform::Args {
-                    dst_layout: layout(dst),
+                    dst_layout: dst.layout(),
                     dst_base: dst.base_mut(),
-                    src_layout: layout(src),
+                    src_layout: src.layout(),
                     src_base: src.base(),
                 },
                 queue,
@@ -143,11 +127,11 @@ impl<Ops: Operators> KernelsA for Ops {
         self.rms_norm_op(queue)
             .launch(
                 &rms_norm::Args {
-                    y_layout: layout(y),
+                    y_layout: y.layout(),
                     y_base: y.base_mut(),
-                    x_layout: layout(x),
+                    x_layout: x.layout(),
                     x_base: x.base(),
-                    w_layout: layout(w),
+                    w_layout: w.layout(),
                     w_base: w.base(),
                     epsilon,
                 },
@@ -169,9 +153,9 @@ impl<Ops: Operators> KernelsA for Ops {
         self.rope_op(queue)
             .launch(
                 &rope::Args {
-                    t_layout: layout(t),
+                    t_layout: t.layout(),
                     t_base: t.base_mut(),
-                    p_layout: layout(pos),
+                    p_layout: pos.layout(),
                     p_base: pos.base(),
                     theta,
                 },
@@ -196,12 +180,12 @@ impl<Ops: Operators> KernelsA for Ops {
         self.mat_mul_op(queue)
             .launch(
                 &mat_mul::Args {
-                    c_layout: layout(c),
+                    c_layout: c.layout(),
                     c_base: c.base_mut(),
                     beta,
-                    a_layout: layout(a),
+                    a_layout: a.layout(),
                     a_base: a.base(),
-                    b_layout: layout(b),
+                    b_layout: b.layout(),
                     b_base: b.base(),
                     alpha,
                 },
@@ -217,7 +201,7 @@ impl<Ops: Operators> KernelsA for Ops {
         self.softmax_op(queue)
             .launch(
                 &fuesd_softmax::Args {
-                    att_layout: layout(att),
+                    att_layout: att.layout(),
                     att_base: att.base_mut(),
                 },
                 queue,
@@ -233,9 +217,9 @@ impl<Ops: Operators> KernelsA for Ops {
         self.swiglu_op(queue)
             .launch(
                 &swiglu::Args {
-                    gate_layout: layout(gate),
+                    gate_layout: gate.layout(),
                     gate_base: gate.base_mut(),
-                    up_layout: layout(up),
+                    up_layout: up.layout(),
                     up_base: up.base(),
                 },
                 queue,
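Every call site above swaps the deleted layout(t) free function for a Tensor::layout() method. The tensor-crate side of that change is not rendered on this page, but the deleted helper and the new operators dependency in tensor/Cargo.toml below suggest the method is the same conversion moved into the tensor crate. A reconstruction, offered as an assumption rather than the actual diff (the Physical type parameter name is illustrative):

use operators::{Argument, TensorLayout};

impl<Physical> Tensor<Physical> {
    /// Describe this tensor to the operator library.
    /// Hypothetical reconstruction of the method the call sites above use.
    pub fn layout(&self) -> TensorLayout {
        let dt = self.data_layout();
        // Dimension sizes pass through unchanged.
        let shape = self
            .shape()
            .iter()
            .map(|&d| Argument::new(d as usize))
            .collect::<Vec<_>>();
        // Strides are stored in elements; the operator library expects bytes.
        let strides = self
            .strides()
            .iter()
            .map(|&s| Argument::new(s as isize * dt.nbytes() as isize))
            .collect::<Vec<_>>();
        TensorLayout::new(dt, shape, strides)
    }
}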
2 changes: 1 addition & 1 deletion models/llama/common/Cargo.toml
@@ -15,5 +15,5 @@ itertools.workspace = true
 digit-layout.workspace = true
 serde = { workspace = true, features = ["derive"] }
 serde_json.workspace = true
-rayon.workspace = true
 operators.workspace = true
+rayon = "1.10"
2 changes: 1 addition & 1 deletion models/llama/common/src/cast.rs
@@ -4,6 +4,7 @@ use digit_layout::{
     types::{BF16, F16, F32},
     AsDigit, DigitLayout,
 };
+use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
 use tensor::Tensor;
 
 impl Storage {
@@ -48,7 +49,6 @@ fn typed<T: AsDigit + Sync, U: AsDigit + Send>(
     src: Tensor<Weight>,
     cast: impl Fn(&T) -> U + Sync,
 ) -> Tensor<Weight> {
-    use rayon::iter::*;
     use tensor::{reslice, reslice_mut};
 
     assert_eq!(src.data_layout(), T::LAYOUT);
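The only change to cast.rs is import hygiene: the glob use rayon::iter::*; inside typed is replaced by three explicit trait imports at the top of the file. The function body is collapsed in this view; as a standalone sketch (hypothetical names, not the crate's actual code), those three traits are exactly what a parallel element-wise cast needs:

use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

// `IntoParallelRefIterator` supplies `par_iter`, `ParallelIterator`
// supplies `map`, and `IndexedParallelIterator` supplies `collect_into_vec`.
fn cast_slice<T: Sync, U: Send>(src: &[T], cast: impl Fn(&T) -> U + Send + Sync) -> Vec<U> {
    let mut dst = Vec::with_capacity(src.len());
    src.par_iter().map(cast).collect_into_vec(&mut dst);
    dst
}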
2 changes: 1 addition & 1 deletion tensor/Cargo.toml
@@ -10,6 +10,6 @@ authors = ["YdrMaster <[email protected]>"]
 smallvec = "1.13"
 nalgebra = "0.32"
 digit-layout.workspace = true
+operators = { workspace = true, features = ["common-cpu"] }
 half.workspace = true
-rayon.workspace = true
 serde.workspace = true
2 changes: 0 additions & 2 deletions tensor/src/lib.rs
@@ -1,5 +1,4 @@
 mod broadcast;
-mod compatibility;
 mod fmt;
 mod pattern;
 mod reshape;
@@ -14,7 +13,6 @@ pub type udim = u32;
 #[allow(non_camel_case_types)]
 pub type idim = i32;
 
-pub use compatibility::Compatibility;
 pub use nalgebra::DVector;
 pub use pattern::{expand_indices, idx_strides, Affine, Shape};
 pub use slice::SliceDim;
(The diff for the remaining changed file did not load.)
