Skip to content

Commit

Permalink
refactor(tensor): reform 直接拷贝只需要二者非 1 维度 stride 相同
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 1, 2024
1 parent 57a9265 commit 9aa5ce5
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 93 deletions.
4 changes: 3 additions & 1 deletion tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod broadcast;
mod data_type;
mod fmt;
mod pattern;
mod physical_cell;
mod reshape;
mod slice;
mod split;
Expand All @@ -17,8 +18,9 @@ pub type idim = i32;
pub use data_type::DataType;
pub use nalgebra::DVector;
pub use pattern::{expand_indices, idx_strides, Affine, Shape};
pub use physical_cell::PhysicalCell;
pub use slice::SliceDim;
pub use tensor::{Storage, Tensor};
pub use tensor::Tensor;

use std::mem::{align_of, size_of, size_of_val};

Expand Down
59 changes: 59 additions & 0 deletions tensor/src/physical_cell.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use crate::Tensor;
use std::ops::{Deref, DerefMut};

pub trait PhysicalCell {
type Raw: ?Sized;
type Access<'a>: Deref<Target = Self::Raw>
where
Self: 'a;
type AccessMut<'a>: DerefMut<Target = Self::Raw>
where
Self: 'a;

unsafe fn get_unchecked(&self) -> &Self::Raw;
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw;
fn access(&self) -> Self::Access<'_>;
fn access_mut(&mut self) -> Self::AccessMut<'_>;
}

impl<Physical: PhysicalCell> Tensor<Physical> {
#[inline]
pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked(),
}
}

#[inline]
pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked_mut(),
}
}

#[inline]
pub fn access(&self) -> Tensor<Physical::Access<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access(),
}
}

#[inline]
pub fn access_mut(&mut self) -> Tensor<Physical::AccessMut<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access_mut(),
}
}
}
163 changes: 72 additions & 91 deletions tensor/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::{expand_indices, idim, idx_strides, pattern::Pattern, udim, DataType, Shape};
use nalgebra::{DVector, DVectorView};
use rayon::iter::*;
use std::ops::{Deref, DerefMut};
use std::{
iter::zip,
ops::{Deref, DerefMut},
panic,
};

#[derive(Clone, Debug)]
pub struct Tensor<Physical> {
Expand Down Expand Up @@ -120,66 +124,48 @@ impl<Physical> Tensor<Physical> {
physical: f(&self.physical),
}
}

#[inline]
fn byte_offset(&self) -> usize {
self.pattern.offset() as usize * self.data_type.size()
}
}

pub trait Storage {
type Raw: ?Sized;
type Access<'a>: Deref<Target = Self::Raw>
where
Self: 'a;
type AccessMut<'a>: DerefMut<Target = Self::Raw>
where
Self: 'a;

unsafe fn get_unchecked(&self) -> &Self::Raw;
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw;
fn access(&self) -> Self::Access<'_>;
fn access_mut(&mut self) -> Self::AccessMut<'_>;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Compatibility {
Same,
Squeeze,
Reform,
None,
}

impl<Physical: Storage> Tensor<Physical> {
#[inline]
pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked(),
impl Compatibility {
pub fn between<T, U>(a: &Tensor<T>, b: &Tensor<U>) -> Self {
if a.data_type != b.data_type {
return Self::None;
}
}

#[inline]
pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked_mut(),
}
}

#[inline]
pub fn access(&self) -> Tensor<Physical::Access<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access(),
let mut actual_a = zip(&a.shape, a.pattern.0.as_slice()).filter(|(&d, _)| d > 1);
let mut actual_b = zip(&b.shape, b.pattern.0.as_slice()).filter(|(&d, _)| d > 1);
let mut squeeze = true;
loop {
match (actual_a.next(), actual_b.next()) {
(Some((da, sa)), Some((db, sb))) => {
if da != db {
return Self::None;
}
if sa != sb {
squeeze = false;
}
}
(Some(_), None) | (None, Some(_)) => return Self::None,
(None, None) => break,
}
}
}

#[inline]
pub fn access_mut(&mut self) -> Tensor<Physical::AccessMut<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access_mut(),
if squeeze {
if a.shape == b.shape {
Self::Same
} else {
Self::Squeeze
}
} else {
Self::Reform
}
}
}
Expand All @@ -188,33 +174,26 @@ impl<Physical: Deref<Target = [u8]>> Tensor<Physical> {
#[inline]
pub fn as_slice(&self) -> &[u8] {
debug_assert!(self.is_contiguous());
let off = self.byte_offset();
let off = self.bytes_offset();
let len = self.bytes_size();
&self.physical[off..][..len]
&self.physical[off as usize..][..len]
}

pub fn locate_start(&self) -> *const u8 {
let off = self.byte_offset();
(&self.physical[off]) as _
let off = self.bytes_offset();
(&self.physical[off as usize]) as _
}

pub fn locate(&self, indices: &DVectorView<idim>) -> Option<*const u8> {
let i = self.pattern.0.dot(indices) as usize * self.data_type.size();
self.physical.get(i).map(|r| r as _)
}

#[inline]
pub fn as_ptr(&self) -> *const u8 {
let ptr = self.physical.as_ptr();
let offset = self.byte_offset();
unsafe { ptr.add(offset) }
}

/// # Safety
///
/// The caller must ensure that the `dst` can be a valid tensor physical.
pub unsafe fn reform_to_raw(&self, dst: &mut [u8]) {
let src = &self.physical[self.byte_offset()..];
let src = &self.physical[self.bytes_offset() as usize..];
// 计算结尾连续维度数量
let contiguous = self.contiguous_len();
if contiguous == self.shape.len() {
Expand All @@ -240,27 +219,29 @@ impl<Physical: Deref<Target = [u8]>> Tensor<Physical> {
where
U: DerefMut<Target = [u8]>,
{
assert_eq!(self.data_type, dst.data_type);
assert_eq!(self.shape, dst.shape);
let contiguous = self.contiguous_len().min(dst.contiguous_len());
if contiguous == self.shape.len() {
dst.as_mut_slice().copy_from_slice(self.as_slice());
} else {
let dt = self.data_type.size();
// 一部分维度连续,迭代不连续的部分
let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous);
let (n, idx_strides) = idx_strides(iter);
let src_pattern = self.pattern.0.view_range(..iter.len(), ..);
let dst_pattern = dst.pattern.0.view_range(..iter.len(), ..);
let src = self.locate_start() as usize;
let dst = dst.locate_start() as usize;
let count = contiguous.iter().product::<udim>() as usize * dt;
(0..n).into_par_iter().for_each(|i| {
let indices = expand_indices(i, &idx_strides, &[]);
let src = (src + src_pattern.dot(&indices) as usize * dt) as *const u8;
let dst = (dst + dst_pattern.dot(&indices) as usize * dt) as *mut u8;
unsafe { std::ptr::copy_nonoverlapping(src, dst, count) };
});
match Compatibility::between(self, dst) {
Compatibility::Same | Compatibility::Squeeze => {
dst.as_mut_slice().copy_from_slice(self.as_slice());
}
Compatibility::Reform => {
let contiguous = self.contiguous_len().min(dst.contiguous_len());
let dt = self.data_type.size();
// 一部分维度连续,迭代不连续的部分
let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous);
let (n, idx_strides) = idx_strides(iter);
let src_pattern = self.pattern.0.view_range(..iter.len(), ..);
let dst_pattern = dst.pattern.0.view_range(..iter.len(), ..);
let src = self.locate_start() as usize;
let dst = dst.locate_start() as usize;
let count = contiguous.iter().product::<udim>() as usize * dt;
(0..n).into_par_iter().for_each(|i| {
let indices = expand_indices(i, &idx_strides, &[]);
let src = (src + src_pattern.dot(&indices) as usize * dt) as *const u8;
let dst = (dst + dst_pattern.dot(&indices) as usize * dt) as *mut u8;
unsafe { std::ptr::copy_nonoverlapping(src, dst, count) };
});
}
Compatibility::None => panic!("Incompatible tensors"),
}
}
}
Expand All @@ -269,14 +250,14 @@ impl<Physical: DerefMut<Target = [u8]>> Tensor<Physical> {
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [u8] {
debug_assert!(self.is_contiguous());
let off = self.byte_offset();
let off = self.bytes_offset();
let len = self.bytes_size();
&mut self.physical[off..][..len]
&mut self.physical[off as usize..][..len]
}

pub fn locate_start_mut(&mut self) -> *mut u8 {
let off = self.byte_offset();
(&mut self.physical[off]) as _
let off = self.bytes_offset();
(&mut self.physical[off as usize]) as _
}

pub fn locate_mut(&mut self, indices: &DVectorView<idim>) -> Option<*mut u8> {
Expand Down
3 changes: 2 additions & 1 deletion transformer-cpu/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
ops::{Deref, DerefMut},
rc::Rc,
};
use tensor::PhysicalCell;

#[derive(Clone, Debug)]
pub struct Storage(Rc<RefCell<Vec<u8>>>);
pub struct VecRef<'a>(Ref<'a, Vec<u8>>);
pub struct VecRefMut<'a>(RefMut<'a, Vec<u8>>);

impl tensor::Storage for Storage {
impl PhysicalCell for Storage {
type Raw = [u8];
type Access<'a> = VecRef<'a>;
type AccessMut<'a> = VecRefMut<'a>;
Expand Down

0 comments on commit 9aa5ce5

Please sign in to comment.