From c2d7128277b100f603f8804c422ddfa327d7012c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 8 Jul 2024 17:09:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor(tensor):=20=E5=BC=A0=E9=87=8F=20reform?= =?UTF-8?q?=20=E5=9F=BA=E4=BA=8E=E7=AE=97=E5=AD=90=E5=BA=93=E7=9A=84?= =?UTF-8?q?=E7=AE=97=E6=B3=95=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- Cargo.lock | 62 ++++++++++++------ Cargo.toml | 3 +- devices/common/src/lib.rs | 44 ++++--------- models/llama/common/Cargo.toml | 2 +- models/llama/common/src/cast.rs | 2 +- tensor/Cargo.toml | 2 +- tensor/src/lib.rs | 2 - tensor/src/tensor.rs | 110 +++++++++++++++++++------------- 8 files changed, 125 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f31cf4e7..cc652c36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -280,25 +280,25 @@ checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" [[package]] name = "cndrv" -version = "0.1.1" -source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" +version = "0.1.2" +source = "git+https://github.com/InfiniTensor/cndrv?rev=1dab7ce#1dab7ce97522974f3065ae9de6b1351d8a37feb9" dependencies = [ "bindgen", "build-script-cfg", "log", - "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=d354eee)", + "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=1dab7ce)", ] [[package]] name = "cnnl" version = "0.0.0" -source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" +source = "git+https://github.com/InfiniTensor/cndrv?rev=1dab7ce#1dab7ce97522974f3065ae9de6b1351d8a37feb9" dependencies = [ "bindgen", "build-script-cfg", "cndrv", "digit-layout", - "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=d354eee)", + "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=1dab7ce)", ] [[package]] @@ -331,7 +331,7 @@ dependencies = [ [[package]] name = "common" version = "0.1.0" -source = "git+https://github.com/YdrMaster/operators-rs?rev=3807deb#3807deb37e74122bb8b86015e8f20c6513593b74" +source = "git+https://github.com/YdrMaster/operators-rs?rev=3e9b113#3e9b1136efa2e167ad09b2bfe24f6ab186c97a64" dependencies = [ "digit-layout", ] @@ -381,7 +381,7 @@ dependencies = [ "operators", "rand", "sample", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", "tensor", ] @@ -419,13 +419,25 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "cublas" version = "0.1.0" -source = "git+https://github.com/YdrMaster/cuda-driver?rev=877df52#877df5295ee9ab0a24737346349e55fe08940b7f" +source = "git+https://github.com/YdrMaster/cuda-driver?rev=0befb35#0befb351182c18ac36f76aab0c6ba3398b8eeee3" dependencies = [ "bindgen", "build-script-cfg", - "cuda", + "cuda 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=0befb35)", "half", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=0befb35)", +] + +[[package]] +name = "cuda" +version = "0.1.0" +source = "git+https://github.com/YdrMaster/cuda-driver?rev=0befb35#0befb351182c18ac36f76aab0c6ba3398b8eeee3" +dependencies = [ + "bindgen", + "build-script-cfg", + "half", + "log", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=0befb35)", ] [[package]] @@ -437,7 +449,7 @@ dependencies = [ "build-script-cfg", "half", "log", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", ] [[package]] @@ -934,7 +946,7 @@ dependencies = [ "itertools 0.13.0", "llama", "log", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", ] [[package]] @@ -950,7 +962,7 @@ dependencies = [ "llama", "log", "nccl", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", "simple_logger", ] @@ -1078,9 +1090,9 @@ source = "git+https://github.com/YdrMaster/cuda-driver?rev=877df52#877df5295ee9a dependencies = [ "bindgen", "build-script-cfg", - "cuda", + "cuda 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", "digit-layout", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", ] [[package]] @@ -1175,20 +1187,20 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "operators" version = "0.0.0" -source = "git+https://github.com/YdrMaster/operators-rs?rev=3807deb#3807deb37e74122bb8b86015e8f20c6513593b74" +source = "git+https://github.com/YdrMaster/operators-rs?rev=3e9b113#3e9b1136efa2e167ad09b2bfe24f6ab186c97a64" dependencies = [ "build-script-cfg", "cndrv", "cnnl", "common 0.1.0", "cublas", - "cuda", + "cuda 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=0befb35)", "digit-layout", "gemm", "half", "log", "rayon", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=0befb35)", "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1444,6 +1456,14 @@ dependencies = [ "rand", ] +[[package]] +name = "search-cuda-tools" +version = "0.1.0" +source = "git+https://github.com/YdrMaster/cuda-driver?rev=0befb35#0befb351182c18ac36f76aab0c6ba3398b8eeee3" +dependencies = [ + "find_cuda_helper", +] + [[package]] name = "search-cuda-tools" version = "0.1.0" @@ -1461,7 +1481,7 @@ checksum = "c83ae237201de85c7ece8ad034bc749a6f229f42f07f62a30008b30a3f1e2231" [[package]] name = "search-neuware-tools" version = "0.0.0" -source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" +source = "git+https://github.com/InfiniTensor/cndrv?rev=1dab7ce#1dab7ce97522974f3065ae9de6b1351d8a37feb9" [[package]] name = "seq-macro" @@ -1601,7 +1621,7 @@ dependencies = [ "digit-layout", "half", "nalgebra", - "rayon", + "operators", "serde", "smallvec", ] @@ -1938,7 +1958,7 @@ dependencies = [ "log", "mixtral", "mixtral-cpu", - "search-cuda-tools", + "search-cuda-tools 0.1.0 (git+https://github.com/YdrMaster/cuda-driver?rev=877df52)", "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "service", "simple_logger", diff --git a/Cargo.toml b/Cargo.toml index cbed6e7c..15828bd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,12 +31,11 @@ itertools = "0.13" serde = "1.0" serde_json = "1.0" memmap2 = "0.9" -rayon = "1.10" tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] } digit-layout = "0.0" build-script-cfg = "0.0" -operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "3807deb", default-features = false } +operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "3e9b113", default-features = false } nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" } search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" } search-neuware-tools = "0.0" diff --git a/devices/common/src/lib.rs b/devices/common/src/lib.rs index d47a3bf9..24c224bf 100644 --- a/devices/common/src/lib.rs +++ b/devices/common/src/lib.rs @@ -1,26 +1,10 @@ use common::utok; use operators::{ - fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Argument, Handle, Operator, QueueOf, - TensorLayout, + fuesd_softmax, mat_mul, reform, rms_norm, rope, swiglu, Handle, Operator, QueueOf, }; use std::ops::{Deref, DerefMut}; use tensor::Tensor; -pub fn layout(t: &Tensor) -> TensorLayout { - let dt = t.data_layout(); - let shape = t - .shape() - .iter() - .map(|&x| Argument::new(x as usize)) - .collect::>(); - let strides = t - .strides() - .iter() - .map(|&x| Argument::new(x as isize * dt.nbytes() as isize)) - .collect::>(); - TensorLayout::new(dt, shape, strides) -} - pub type SliceOn = [::Byte]; pub trait Operators { @@ -118,9 +102,9 @@ impl KernelsA for Ops { self.reform_op(queue) .launch( &reform::Args { - dst_layout: layout(dst), + dst_layout: dst.layout(), dst_base: dst.base_mut(), - src_layout: layout(src), + src_layout: src.layout(), src_base: src.base(), }, queue, @@ -143,11 +127,11 @@ impl KernelsA for Ops { self.rms_norm_op(queue) .launch( &rms_norm::Args { - y_layout: layout(y), + y_layout: y.layout(), y_base: y.base_mut(), - x_layout: layout(x), + x_layout: x.layout(), x_base: x.base(), - w_layout: layout(w), + w_layout: w.layout(), w_base: w.base(), epsilon, }, @@ -169,9 +153,9 @@ impl KernelsA for Ops { self.rope_op(queue) .launch( &rope::Args { - t_layout: layout(t), + t_layout: t.layout(), t_base: t.base_mut(), - p_layout: layout(pos), + p_layout: pos.layout(), p_base: pos.base(), theta, }, @@ -196,12 +180,12 @@ impl KernelsA for Ops { self.mat_mul_op(queue) .launch( &mat_mul::Args { - c_layout: layout(c), + c_layout: c.layout(), c_base: c.base_mut(), beta, - a_layout: layout(a), + a_layout: a.layout(), a_base: a.base(), - b_layout: layout(b), + b_layout: b.layout(), b_base: b.base(), alpha, }, @@ -217,7 +201,7 @@ impl KernelsA for Ops { self.softmax_op(queue) .launch( &fuesd_softmax::Args { - att_layout: layout(att), + att_layout: att.layout(), att_base: att.base_mut(), }, queue, @@ -233,9 +217,9 @@ impl KernelsA for Ops { self.swiglu_op(queue) .launch( &swiglu::Args { - gate_layout: layout(gate), + gate_layout: gate.layout(), gate_base: gate.base_mut(), - up_layout: layout(up), + up_layout: up.layout(), up_base: up.base(), }, queue, diff --git a/models/llama/common/Cargo.toml b/models/llama/common/Cargo.toml index 0ccb0985..07624c0c 100644 --- a/models/llama/common/Cargo.toml +++ b/models/llama/common/Cargo.toml @@ -15,5 +15,5 @@ itertools.workspace = true digit-layout.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true -rayon.workspace = true operators.workspace = true +rayon = "1.10" diff --git a/models/llama/common/src/cast.rs b/models/llama/common/src/cast.rs index 04d5c277..5e76639f 100644 --- a/models/llama/common/src/cast.rs +++ b/models/llama/common/src/cast.rs @@ -4,6 +4,7 @@ use digit_layout::{ types::{BF16, F16, F32}, AsDigit, DigitLayout, }; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; use tensor::Tensor; impl Storage { @@ -48,7 +49,6 @@ fn typed( src: Tensor, cast: impl Fn(&T) -> U + Sync, ) -> Tensor { - use rayon::iter::*; use tensor::{reslice, reslice_mut}; assert_eq!(src.data_layout(), T::LAYOUT); diff --git a/tensor/Cargo.toml b/tensor/Cargo.toml index 250237c1..2d45871a 100644 --- a/tensor/Cargo.toml +++ b/tensor/Cargo.toml @@ -10,6 +10,6 @@ authors = ["YdrMaster "] smallvec = "1.13" nalgebra = "0.32" digit-layout.workspace = true +operators = { workspace = true, features = ["common-cpu"] } half.workspace = true -rayon.workspace = true serde.workspace = true diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 10d208b2..0f140271 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -1,5 +1,4 @@ mod broadcast; -mod compatibility; mod fmt; mod pattern; mod reshape; @@ -14,7 +13,6 @@ pub type udim = u32; #[allow(non_camel_case_types)] pub type idim = i32; -pub use compatibility::Compatibility; pub use nalgebra::DVector; pub use pattern::{expand_indices, idx_strides, Affine, Shape}; pub use slice::SliceDim; diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs index e3af6167..02fd79bb 100644 --- a/tensor/src/tensor.rs +++ b/tensor/src/tensor.rs @@ -1,11 +1,10 @@ -use crate::{expand_indices, idim, idx_strides, pattern::Pattern, udim, Compatibility, Shape}; +use crate::{idim, pattern::Pattern, udim, Shape}; use digit_layout::DigitLayout; use nalgebra::DVector; -use rayon::iter::*; +use operators::{Argument, Operator, TensorLayout}; use std::{ mem::{align_of, size_of}, ops::{Deref, DerefMut}, - panic, }; #[derive(Clone, Debug)] @@ -161,6 +160,21 @@ impl Tensor { physical: f(self.physical), } } + + pub fn layout(&self) -> TensorLayout { + let dt = self.data_layout(); + let shape: Vec> = self + .shape + .iter() + .map(|&x| Argument::new(x as usize)) + .collect::>(); + let strides = self + .strides() + .iter() + .map(|&x| Argument::new(x as isize * dt.nbytes() as isize)) + .collect::>(); + TensorLayout::new(dt, shape, strides) + } } impl> Tensor

{ @@ -202,53 +216,61 @@ impl> Tensor { /// /// The caller must ensure that the `dst` can be a valid tensor physical. pub unsafe fn reform_to_raw(&self, dst: &mut [u8]) { - let src = &self.physical[self.bytes_offset() as usize..]; - // 计算结尾连续维度数量 - let contiguous = self.contiguous_len(); - if contiguous == self.shape.len() { - // 所有维度都连续,直接拷贝所有数据 - dst.copy_from_slice(&src[..dst.len()]); - } else { - let dt = self.layout.nbytes(); - // 一部分维度连续,迭代不连续的部分 - let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous); - let (n, idx_strides) = idx_strides(iter); - let len = contiguous.iter().product::() as usize * dt; - let pattern = self.pattern.0.view_range(..iter.len(), ..); - let ptr = dst.as_mut_ptr() as usize; - (0..n).into_par_iter().for_each(|i| { - let j = pattern.dot(&expand_indices(i, &idx_strides, &[])); - unsafe { std::slice::from_raw_parts_mut((ptr + i as usize * len) as *mut u8, len) } - .copy_from_slice(&src[j as usize * dt..][..len]); - }); - } + assert_eq!(self.bytes_size(), dst.len()); + use operators::{ + common_cpu::{Handle as Cpu, ThisThread}, + reform::{common_cpu::Operator as Reform, Args}, + }; + Reform::new(&Cpu) + .launch( + &Args { + dst_layout: { + let shape = self + .shape + .iter() + .map(|&d| Argument::new(d as usize)) + .collect::>(); + let mut strides = self + .shape + .iter() + .rev() + .scan(self.layout.nbytes() as isize, |mul, &d| { + let s = *mul; + *mul *= d as isize; + Some(Argument::new(s)) + }) + .collect::>(); + strides.reverse(); + TensorLayout::new(self.layout, shape, strides) + }, + dst_base: dst.as_mut_ptr(), + src_layout: self.layout(), + src_base: self.base(), + }, + &ThisThread, + ) + .unwrap(); } pub fn reform_to(&self, dst: &mut Tensor) where U: DerefMut, { - match Compatibility::between(self, dst) { - Compatibility::None => panic!("Incompatible tensors"), - _ => { - let contiguous = self.contiguous_len().min(dst.contiguous_len()); - let dt = self.layout.nbytes(); - // 一部分维度连续,迭代不连续的部分 - let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous); - let (n, idx_strides) = idx_strides(iter); - let src_pattern = self.pattern.0.view_range(..iter.len(), ..); - let dst_pattern = dst.pattern.0.view_range(..iter.len(), ..); - let src = self.base() as usize; - let dst = dst.base() as usize; - let count = contiguous.iter().product::() as usize * dt; - (0..n).into_par_iter().for_each(|i| { - let indices = expand_indices(i, &idx_strides, &[]); - let src = (src + src_pattern.dot(&indices) as usize * dt) as *const u8; - let dst = (dst + dst_pattern.dot(&indices) as usize * dt) as *mut u8; - unsafe { std::ptr::copy_nonoverlapping(src, dst, count) }; - }); - } - } + use operators::{ + common_cpu::{Handle as Cpu, ThisThread}, + reform::{common_cpu::Operator as Reform, Args}, + }; + Reform::new(&Cpu) + .launch( + &Args { + dst_layout: dst.layout(), + dst_base: dst.base_mut(), + src_layout: self.layout(), + src_base: self.base(), + }, + &ThisThread, + ) + .unwrap(); } }