From b3694b7fcee3b18b619c4b7877d861562fefdc5d Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 23 Feb 2024 16:31:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(tensor):=20=E5=85=81=E8=AE=B8=20Storage=20?= =?UTF-8?q?=E7=BB=95=E8=BF=87=E5=80=9F=E7=94=A8=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- tensor/src/lib.rs | 1 + tensor/src/tensor.rs | 27 +++++++++++++++++++++++++-- transformer-cpu/src/storage.rs | 11 +++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 1296bf60..118f787b 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -13,6 +13,7 @@ pub type idim = i32; use std::mem::{align_of, size_of}; pub use data_type::DataType; +pub use nalgebra::DVector; pub use operator::{Operator, SliceDim}; pub use pattern::{expand_indices, idx_strides, Affine, Shape}; pub use tensor::{Storage, Tensor}; diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs index 29663184..e9c768d5 100644 --- a/tensor/src/tensor.rs +++ b/tensor/src/tensor.rs @@ -294,18 +294,41 @@ impl Tensor { } pub trait Storage { - type Access<'a> + type Raw: ?Sized; + type Access<'a>: Deref where Self: 'a; - type AccessMut<'a> + type AccessMut<'a>: DerefMut where Self: 'a; + unsafe fn get_unchecked(&self) -> &Self::Raw; + unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw; fn access(&self) -> Self::Access<'_>; fn access_mut(&mut self) -> Self::AccessMut<'_>; } impl Tensor { + #[inline] + pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.get_unchecked(), + } + } + + #[inline] + pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.get_unchecked_mut(), + } + } + #[inline] pub fn access(&self) -> Tensor> { Tensor { diff --git a/transformer-cpu/src/storage.rs b/transformer-cpu/src/storage.rs index e1621fe6..256a06f7 100644 --- a/transformer-cpu/src/storage.rs +++ b/transformer-cpu/src/storage.rs @@ -10,9 +10,20 @@ pub struct VecRef<'a>(Ref<'a, Vec>); pub struct VecRefMut<'a>(RefMut<'a, Vec>); impl tensor::Storage for Storage { + type Raw = [u8]; type Access<'a> = VecRef<'a>; type AccessMut<'a> = VecRefMut<'a>; + #[inline] + unsafe fn get_unchecked(&self) -> &Self::Raw { + &*self.0.as_ptr() + } + + #[inline] + unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw { + &mut *self.0.as_ptr() + } + #[inline] fn access(&self) -> Self::Access<'_> { VecRef(self.0.borrow())