diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 6736c3b8..cb6be893 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -2,6 +2,7 @@ mod broadcast; mod data_type; mod fmt; mod pattern; +mod physical_cell; mod reshape; mod slice; mod split; @@ -17,8 +18,9 @@ pub type idim = i32; pub use data_type::DataType; pub use nalgebra::DVector; pub use pattern::{expand_indices, idx_strides, Affine, Shape}; +pub use physical_cell::PhysicalCell; pub use slice::SliceDim; -pub use tensor::{Storage, Tensor}; +pub use tensor::Tensor; use std::mem::{align_of, size_of, size_of_val}; diff --git a/tensor/src/physical_cell.rs b/tensor/src/physical_cell.rs new file mode 100644 index 00000000..9530b4b7 --- /dev/null +++ b/tensor/src/physical_cell.rs @@ -0,0 +1,59 @@ +use crate::Tensor; +use std::ops::{Deref, DerefMut}; + +pub trait PhysicalCell { + type Raw: ?Sized; + type Access<'a>: Deref + where + Self: 'a; + type AccessMut<'a>: DerefMut + where + Self: 'a; + + unsafe fn get_unchecked(&self) -> &Self::Raw; + unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw; + fn access(&self) -> Self::Access<'_>; + fn access_mut(&mut self) -> Self::AccessMut<'_>; +} + +impl Tensor { + #[inline] + pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.get_unchecked(), + } + } + + #[inline] + pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.get_unchecked_mut(), + } + } + + #[inline] + pub fn access(&self) -> Tensor> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.access(), + } + } + + #[inline] + pub fn access_mut(&mut self) -> Tensor> { + Tensor { + data_type: self.data_type, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical: self.physical.access_mut(), + } + } +} diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs index 7b668d9c..845c094d 100644 --- a/tensor/src/tensor.rs +++ b/tensor/src/tensor.rs @@ -1,7 +1,11 @@ use crate::{expand_indices, idim, idx_strides, pattern::Pattern, udim, DataType, Shape}; use nalgebra::{DVector, DVectorView}; use rayon::iter::*; -use std::ops::{Deref, DerefMut}; +use std::{ + iter::zip, + ops::{Deref, DerefMut}, + panic, +}; #[derive(Clone, Debug)] pub struct Tensor { @@ -120,66 +124,48 @@ impl Tensor { physical: f(&self.physical), } } - - #[inline] - fn byte_offset(&self) -> usize { - self.pattern.offset() as usize * self.data_type.size() - } } -pub trait Storage { - type Raw: ?Sized; - type Access<'a>: Deref - where - Self: 'a; - type AccessMut<'a>: DerefMut - where - Self: 'a; - - unsafe fn get_unchecked(&self) -> &Self::Raw; - unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw; - fn access(&self) -> Self::Access<'_>; - fn access_mut(&mut self) -> Self::AccessMut<'_>; +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[repr(u8)] +pub enum Compatibility { + Same, + Squeeze, + Reform, + None, } -impl Tensor { - #[inline] - pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> { - Tensor { - data_type: self.data_type, - shape: self.shape.clone(), - pattern: self.pattern.clone(), - physical: self.physical.get_unchecked(), +impl Compatibility { + pub fn between(a: &Tensor, b: &Tensor) -> Self { + if a.data_type != b.data_type { + return Self::None; } - } - #[inline] - pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> { - Tensor { - data_type: self.data_type, - shape: self.shape.clone(), - pattern: self.pattern.clone(), - physical: self.physical.get_unchecked_mut(), - } - } - - #[inline] - pub fn access(&self) -> Tensor> { - Tensor { - data_type: self.data_type, - shape: self.shape.clone(), - pattern: self.pattern.clone(), - physical: self.physical.access(), + let mut actual_a = zip(&a.shape, a.pattern.0.as_slice()).filter(|(&d, _)| d > 1); + let mut actual_b = zip(&b.shape, b.pattern.0.as_slice()).filter(|(&d, _)| d > 1); + let mut squeeze = true; + loop { + match (actual_a.next(), actual_b.next()) { + (Some((da, sa)), Some((db, sb))) => { + if da != db { + return Self::None; + } + if sa != sb { + squeeze = false; + } + } + (Some(_), None) | (None, Some(_)) => return Self::None, + (None, None) => break, + } } - } - - #[inline] - pub fn access_mut(&mut self) -> Tensor> { - Tensor { - data_type: self.data_type, - shape: self.shape.clone(), - pattern: self.pattern.clone(), - physical: self.physical.access_mut(), + if squeeze { + if a.shape == b.shape { + Self::Same + } else { + Self::Squeeze + } + } else { + Self::Reform } } } @@ -188,14 +174,14 @@ impl> Tensor { #[inline] pub fn as_slice(&self) -> &[u8] { debug_assert!(self.is_contiguous()); - let off = self.byte_offset(); + let off = self.bytes_offset(); let len = self.bytes_size(); - &self.physical[off..][..len] + &self.physical[off as usize..][..len] } pub fn locate_start(&self) -> *const u8 { - let off = self.byte_offset(); - (&self.physical[off]) as _ + let off = self.bytes_offset(); + (&self.physical[off as usize]) as _ } pub fn locate(&self, indices: &DVectorView) -> Option<*const u8> { @@ -203,18 +189,11 @@ impl> Tensor { self.physical.get(i).map(|r| r as _) } - #[inline] - pub fn as_ptr(&self) -> *const u8 { - let ptr = self.physical.as_ptr(); - let offset = self.byte_offset(); - unsafe { ptr.add(offset) } - } - /// # Safety /// /// The caller must ensure that the `dst` can be a valid tensor physical. pub unsafe fn reform_to_raw(&self, dst: &mut [u8]) { - let src = &self.physical[self.byte_offset()..]; + let src = &self.physical[self.bytes_offset() as usize..]; // 计算结尾连续维度数量 let contiguous = self.contiguous_len(); if contiguous == self.shape.len() { @@ -240,27 +219,29 @@ impl> Tensor { where U: DerefMut, { - assert_eq!(self.data_type, dst.data_type); - assert_eq!(self.shape, dst.shape); - let contiguous = self.contiguous_len().min(dst.contiguous_len()); - if contiguous == self.shape.len() { - dst.as_mut_slice().copy_from_slice(self.as_slice()); - } else { - let dt = self.data_type.size(); - // 一部分维度连续,迭代不连续的部分 - let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous); - let (n, idx_strides) = idx_strides(iter); - let src_pattern = self.pattern.0.view_range(..iter.len(), ..); - let dst_pattern = dst.pattern.0.view_range(..iter.len(), ..); - let src = self.locate_start() as usize; - let dst = dst.locate_start() as usize; - let count = contiguous.iter().product::() as usize * dt; - (0..n).into_par_iter().for_each(|i| { - let indices = expand_indices(i, &idx_strides, &[]); - let src = (src + src_pattern.dot(&indices) as usize * dt) as *const u8; - let dst = (dst + dst_pattern.dot(&indices) as usize * dt) as *mut u8; - unsafe { std::ptr::copy_nonoverlapping(src, dst, count) }; - }); + match Compatibility::between(self, dst) { + Compatibility::Same | Compatibility::Squeeze => { + dst.as_mut_slice().copy_from_slice(self.as_slice()); + } + Compatibility::Reform => { + let contiguous = self.contiguous_len().min(dst.contiguous_len()); + let dt = self.data_type.size(); + // 一部分维度连续,迭代不连续的部分 + let (iter, contiguous) = self.shape.split_at(self.shape.len() - contiguous); + let (n, idx_strides) = idx_strides(iter); + let src_pattern = self.pattern.0.view_range(..iter.len(), ..); + let dst_pattern = dst.pattern.0.view_range(..iter.len(), ..); + let src = self.locate_start() as usize; + let dst = dst.locate_start() as usize; + let count = contiguous.iter().product::() as usize * dt; + (0..n).into_par_iter().for_each(|i| { + let indices = expand_indices(i, &idx_strides, &[]); + let src = (src + src_pattern.dot(&indices) as usize * dt) as *const u8; + let dst = (dst + dst_pattern.dot(&indices) as usize * dt) as *mut u8; + unsafe { std::ptr::copy_nonoverlapping(src, dst, count) }; + }); + } + Compatibility::None => panic!("Incompatible tensors"), } } } @@ -269,14 +250,14 @@ impl> Tensor { #[inline] pub fn as_mut_slice(&mut self) -> &mut [u8] { debug_assert!(self.is_contiguous()); - let off = self.byte_offset(); + let off = self.bytes_offset(); let len = self.bytes_size(); - &mut self.physical[off..][..len] + &mut self.physical[off as usize..][..len] } pub fn locate_start_mut(&mut self) -> *mut u8 { - let off = self.byte_offset(); - (&mut self.physical[off]) as _ + let off = self.bytes_offset(); + (&mut self.physical[off as usize]) as _ } pub fn locate_mut(&mut self, indices: &DVectorView) -> Option<*mut u8> { diff --git a/transformer-cpu/src/storage.rs b/transformer-cpu/src/storage.rs index 256a06f7..cbed0834 100644 --- a/transformer-cpu/src/storage.rs +++ b/transformer-cpu/src/storage.rs @@ -3,13 +3,14 @@ ops::{Deref, DerefMut}, rc::Rc, }; +use tensor::PhysicalCell; #[derive(Clone, Debug)] pub struct Storage(Rc>>); pub struct VecRef<'a>(Ref<'a, Vec>); pub struct VecRefMut<'a>(RefMut<'a, Vec>); -impl tensor::Storage for Storage { +impl PhysicalCell for Storage { type Raw = [u8]; type Access<'a> = VecRef<'a>; type AccessMut<'a> = VecRefMut<'a>;