Skip to content

Commit

Permalink
refactor: 推理接入算子库的 reform
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jul 5, 2024
1 parent b2c33b1 commit f917ff5
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 54 deletions.
39 changes: 29 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
digit-layout = "0.0"
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "189c2d5", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "3807deb", default-features = false }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" }

cndrv = "0.1"
search-neuware-tools = "0.0"
3 changes: 1 addition & 2 deletions devices/cambricon-mlu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../../common" }
common-devices = { path = "../common" }
tensor = { path = "../../tensor" }
cndrv.workspace = true
# operators = { workspace = true, features = ["cambricon-mlu"] }
operators = { workspace = true, features = ["cambricon-mlu"] }

[build-dependencies]
build-script-cfg.workspace = true
Expand Down
3 changes: 1 addition & 2 deletions devices/cambricon-mlu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![cfg(detected_neuware)]

pub extern crate cndrv;

pub use operators::cndrv;
pub use tensor::Tensor;

pub fn synchronize() {
Expand Down
20 changes: 10 additions & 10 deletions devices/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use common::utok;
use common_devices::{Operators, SliceOn};
use operators::{
fuesd_softmax::common_cpu as softmax, mat_mul::common_cpu as mat_mul,
rms_norm::common_cpu as rms_norm, rope::common_cpu as rope, swiglu::common_cpu as swiglu,
Operator, QueueOf,
reform::common_cpu as reform, rms_norm::common_cpu as rms_norm, rope::common_cpu as rope,
swiglu::common_cpu as swiglu, Operator, QueueOf,
};
use std::ops::{Deref, DerefMut};
use tensor::Tensor;
Expand All @@ -23,6 +23,7 @@ pub use common_devices::{Kernels, KernelsA, KernelsB};
pub use operators::common_cpu::{Handle as Cpu, ThisThread};

pub struct CpuKernels {
reform: reform::Operator,
mat_mul: mat_mul::Operator,
rms_norm: rms_norm::Operator,
rope: rope::Operator,
Expand All @@ -33,6 +34,7 @@ pub struct CpuKernels {
impl Default for CpuKernels {
fn default() -> Self {
Self {
reform: reform::Operator::new(&Cpu),
mat_mul: mat_mul::Operator::new(&Cpu),
rms_norm: rms_norm::Operator::new(&Cpu),
rope: rope::Operator::new(&Cpu),
Expand All @@ -47,6 +49,12 @@ impl Kernels<Cpu> for CpuKernels {}
impl Operators for CpuKernels {
type Handle = Cpu;

/// Returns the CPU reform operator; the queue argument is unused on CPU.
fn reform_op(
&self,
_: &QueueOf<Self::Handle>,
) -> &impl operators::reform::Reform<Self::Handle> {
&self.reform
}
fn rms_norm_op(
&self,
_: &QueueOf<Self::Handle>,
Expand Down Expand Up @@ -92,12 +100,4 @@ impl KernelsB for CpuKernels {
{
gather::gather(x, table, tokens);
}

/// Copies `src` into `dst` via `Tensor::reform_to`; the queue is unused on CPU.
// NOTE(review): presumably rearranges data to match `dst`'s layout — confirm against Tensor::reform_to.
fn reform<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, _queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>,
{
src.reform_to(dst);
}
}
31 changes: 25 additions & 6 deletions devices/common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use common::utok;
use operators::{
fuesd_softmax, mat_mul, rms_norm, rope, swiglu, Argument, Handle, Operator, QueueOf,
fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Argument, Handle, Operator, QueueOf,
TensorLayout,
};
use std::ops::{Deref, DerefMut};
Expand All @@ -26,6 +26,7 @@ pub type SliceOn<H> = [<H as Handle>::Byte];
pub trait Operators {
type Handle: Handle;

fn reform_op(&self, queue: &QueueOf<Self::Handle>) -> &impl reform::Reform<Self::Handle>;
fn rms_norm_op(&self, queue: &QueueOf<Self::Handle>) -> &impl rms_norm::RmsNorm<Self::Handle>;
fn mat_mul_op(&self, queue: &QueueOf<Self::Handle>) -> &impl mat_mul::MatMul<Self::Handle>;
fn rope_op(&self, queue: &QueueOf<Self::Handle>) -> &impl rope::Rope<Self::Handle>;
Expand All @@ -39,6 +40,11 @@ pub trait Operators {
pub trait KernelsA {
type Handle: Handle;

fn reform<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>;

fn rms_norm<T, U, V>(
&self,
y: &mut Tensor<T>,
Expand Down Expand Up @@ -97,18 +103,31 @@ pub trait KernelsB {
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = [u8]>,
I: IntoIterator<Item = utok>;

fn reform<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>;
}

pub trait Kernels<H: Handle>: KernelsA<Handle = H> + KernelsB<Handle = H> {}

impl<Ops: Operators> KernelsA for Ops {
type Handle = <Ops as Operators>::Handle;

/// Copies `src` into `dst` by dispatching to the handle's reform operator.
///
/// Builds the operator arguments from both tensors' layouts and base
/// pointers, then launches on `queue`. Panics if the launch fails.
fn reform<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
    T: DerefMut<Target = SliceOn<Self::Handle>>,
    U: Deref<Target = SliceOn<Self::Handle>>,
{
    // Assemble arguments first so the launch call stays on one line.
    let args = reform::Args {
        dst_layout: layout(dst),
        dst_base: dst.base_mut(),
        src_layout: layout(src),
        src_base: src.base(),
    };
    self.reform_op(queue).launch(&args, queue).unwrap();
}

fn rms_norm<T, U, V>(
&self,
y: &mut Tensor<T>,
Expand Down
28 changes: 8 additions & 20 deletions devices/nvidia-gpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod gather;
mod sample;

use common::utok;
use common_devices::{layout, Operators, SliceOn};
use common_devices::{Operators, SliceOn};
use cuda::{AsRaw, Device};
use digit_layout::types::{F16, U32};
use operators::{
Expand Down Expand Up @@ -127,6 +127,13 @@ impl Kernels<Gpu> for NvidiaKernels {}
impl Operators for NvidiaKernels {
type Handle = Gpu;

/// Returns the reform operator of the per-queue (per-device) kernel set.
fn reform_op(
&self,
queue: &QueueOf<Self::Handle>,
) -> &impl operators::reform::Reform<Self::Handle> {
&self.get(queue).reform
}

fn rms_norm_op(
&self,
queue: &QueueOf<Self::Handle>,
Expand Down Expand Up @@ -176,25 +183,6 @@ impl KernelsB for NvidiaKernels {
{
gather::gather(x, table, tokens, queue);
}

/// Copies `src` into `dst` on the GPU by launching the per-queue reform
/// operator with both tensors' layouts and base pointers.
/// Panics if the launch fails.
fn reform<T, U>(&self, dst: &mut Tensor<T>, src: &Tensor<U>, queue: &QueueOf<Self::Handle>)
where
T: DerefMut<Target = SliceOn<Self::Handle>>,
U: Deref<Target = SliceOn<Self::Handle>>,
{
self.get(queue)
.reform
.launch(
&operators::reform::Args {
dst_layout: layout(dst),
dst_base: dst.base_mut(),
src_layout: layout(src),
src_base: src.base(),
},
queue,
)
.unwrap();
}
}

pub fn synchronize() {
Expand Down
2 changes: 1 addition & 1 deletion models/llama/common/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use causal_lm::QueryContext;
use common_devices::{Kernels, KernelsA, KernelsB, SliceOn};
use common_devices::{Kernels, KernelsA, SliceOn};
use itertools::izip;
use operators::{Handle, QueueOf};
use std::ops::{Deref, DerefMut};
Expand Down

0 comments on commit f917ff5

Please sign in to comment.