-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(transformer-cpu): 整理、优化 transformer cpu
Signed-off-by: YdrMaster <[email protected]>
- Loading branch information
Showing
6 changed files
with
83 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,58 @@ | ||
use std::{ | ||
cell::{Ref, RefCell, RefMut}, | ||
alloc::{alloc, Layout}, | ||
mem::align_of, | ||
ops::{Deref, DerefMut}, | ||
rc::Rc, | ||
}; | ||
use tensor::PhysicalCell; | ||
use tensor::Splitable; | ||
|
||
#[derive(Debug)] | ||
#[repr(transparent)] | ||
pub struct Unique(Vec<u8>); | ||
pub struct Storage(Rc<Internal>); | ||
|
||
impl Deref for Unique { | ||
type Target = [u8]; | ||
|
||
#[inline] | ||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
} | ||
} | ||
|
||
impl DerefMut for Unique { | ||
#[inline] | ||
fn deref_mut(&mut self) -> &mut Self::Target { | ||
&mut self.0 | ||
} | ||
} | ||
|
||
impl Unique { | ||
impl Storage { | ||
#[inline] | ||
pub fn new(size: usize) -> Self { | ||
Self(vec![0u8; size]) | ||
const ALIGN: usize = align_of::<usize>(); | ||
let layout = Layout::from_size_align(size, ALIGN).unwrap(); | ||
Self(Rc::new(Internal { | ||
ptr: unsafe { alloc(layout) }, | ||
len: size, | ||
})) | ||
} | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct Cell(Rc<RefCell<Vec<u8>>>); | ||
pub struct VecRef<'a>(Ref<'a, Vec<u8>>); | ||
pub struct VecRefMut<'a>(RefMut<'a, Vec<u8>>); | ||
|
||
impl PhysicalCell for Cell { | ||
type Raw = [u8]; | ||
type Access<'a> = VecRef<'a>; | ||
type AccessMut<'a> = VecRefMut<'a>; | ||
|
||
#[inline] | ||
unsafe fn get_unchecked(&self) -> &Self::Raw { | ||
&*self.0.as_ptr() | ||
} | ||
|
||
#[inline] | ||
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw { | ||
&mut *self.0.as_ptr() | ||
} | ||
|
||
impl Splitable for Storage { | ||
#[inline] | ||
fn access(&self) -> Self::Access<'_> { | ||
VecRef(self.0.borrow()) | ||
} | ||
|
||
#[inline] | ||
fn access_mut(&mut self) -> Self::AccessMut<'_> { | ||
VecRefMut(self.0.borrow_mut()) | ||
fn split(&self) -> Self { | ||
Self(self.0.clone()) | ||
} | ||
} | ||
|
||
impl Deref for VecRef<'_> { | ||
impl Deref for Storage { | ||
type Target = [u8]; | ||
|
||
#[inline] | ||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
unsafe { std::slice::from_raw_parts(self.0.ptr, self.0.len) } | ||
} | ||
} | ||
|
||
impl Deref for VecRefMut<'_> { | ||
type Target = [u8]; | ||
|
||
impl DerefMut for Storage { | ||
#[inline] | ||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
fn deref_mut(&mut self) -> &mut Self::Target { | ||
unsafe { std::slice::from_raw_parts_mut(self.0.ptr, self.0.len) } | ||
} | ||
} | ||
|
||
impl DerefMut for VecRefMut<'_> { | ||
#[inline] | ||
fn deref_mut(&mut self) -> &mut Self::Target { | ||
&mut self.0 | ||
} | ||
struct Internal { | ||
ptr: *mut u8, | ||
len: usize, | ||
} | ||
|
||
impl Cell { | ||
pub fn new(size: usize) -> Self { | ||
Self(Rc::new(RefCell::new(vec![0; size]))) | ||
impl Drop for Internal { | ||
#[inline] | ||
fn drop(&mut self) { | ||
const ALIGN: usize = align_of::<usize>(); | ||
let layout = Layout::from_size_align(self.len, ALIGN).unwrap(); | ||
unsafe { std::alloc::dealloc(self.ptr, layout) } | ||
} | ||
} |