Skip to content

Commit

Permalink
feat(tensor): 允许 Storage 绕过借用检查
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 23, 2024
1 parent 45b4d64 commit b3694b7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub type idim = i32;
use std::mem::{align_of, size_of};

pub use data_type::DataType;
pub use nalgebra::DVector;
pub use operator::{Operator, SliceDim};
pub use pattern::{expand_indices, idx_strides, Affine, Shape};
pub use tensor::{Storage, Tensor};
Expand Down
27 changes: 25 additions & 2 deletions tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,18 +294,41 @@ impl<Physical: Clone> Tensor<Physical> {
}

pub trait Storage {
type Access<'a>
type Raw: ?Sized;
type Access<'a>: Deref<Target = Self::Raw>
where
Self: 'a;
type AccessMut<'a>
type AccessMut<'a>: DerefMut<Target = Self::Raw>
where
Self: 'a;

unsafe fn get_unchecked(&self) -> &Self::Raw;
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw;
fn access(&self) -> Self::Access<'_>;
fn access_mut(&mut self) -> Self::AccessMut<'_>;
}

impl<Physical: Storage> Tensor<Physical> {
#[inline]
pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked(),
}
}

#[inline]
pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked_mut(),
}
}

#[inline]
pub fn access(&self) -> Tensor<Physical::Access<'_>> {
Tensor {
Expand Down
11 changes: 11 additions & 0 deletions transformer-cpu/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@ pub struct VecRef<'a>(Ref<'a, Vec<u8>>);
pub struct VecRefMut<'a>(RefMut<'a, Vec<u8>>);

impl tensor::Storage for Storage {
type Raw = [u8];
type Access<'a> = VecRef<'a>;
type AccessMut<'a> = VecRefMut<'a>;

#[inline]
unsafe fn get_unchecked(&self) -> &Self::Raw {
&*self.0.as_ptr()
}

#[inline]
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw {
&mut *self.0.as_ptr()
}

#[inline]
fn access(&self) -> Self::Access<'_> {
VecRef(self.0.borrow())
Expand Down

0 comments on commit b3694b7

Please sign in to comment.