Skip to content

Commit

Permalink
refactor(tensor): 现在 Tensor 被视作一个 Physical 的容器,以方便变换 Physical
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 14, 2024
1 parent 619b507 commit 08528ab
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 34 deletions.
2 changes: 1 addition & 1 deletion service/src/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl CpuTask {
info!("prefill transformer ... {:?}", time.elapsed());

loop {
let token = argmax(reslice::<u8, f16>(logits.access().as_slice()));
let token = argmax(reslice::<u8, f16>(logits.as_slice()));
if token == eos {
break;
}
Expand Down
30 changes: 6 additions & 24 deletions tensor/src/physical_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,23 @@ pub trait PhysicalCell {
impl<Physical: PhysicalCell> Tensor<Physical> {
#[inline]
pub unsafe fn access_unchecked(&self) -> Tensor<&Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked(),
}
self.as_ref()
.map_physical(|physical| physical.get_unchecked())
}

#[inline]
pub unsafe fn access_unchecked_mut(&mut self) -> Tensor<&mut Physical::Raw> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.get_unchecked_mut(),
}
self.as_mut()
.map_physical(|physical| physical.get_unchecked_mut())
}

#[inline]
pub fn access(&self) -> Tensor<Physical::Access<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access(),
}
unsafe { self.as_ref().map_physical(|physical| physical.access()) }
}

#[inline]
pub fn access_mut(&mut self) -> Tensor<Physical::AccessMut<'_>> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: self.physical.access_mut(),
}
unsafe { self.as_mut().map_physical(|physical| physical.access_mut()) }
}
}
28 changes: 24 additions & 4 deletions tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,36 @@ impl<Physical> Tensor<Physical> {
.count()
}

/// Builds a tensor that borrows this tensor's `physical` immutably.
///
/// `data_type` is copied and `shape` and `pattern` are cloned, so the
/// returned tensor owns its own metadata and holds only a shared
/// reference to the `physical` of `self`.
#[inline]
pub fn as_ref(&self) -> Tensor<&Physical> {
    // Clone the metadata in the same order as the fields are declared below.
    let shape = self.shape.clone();
    let pattern = self.pattern.clone();
    Tensor {
        data_type: self.data_type,
        shape,
        pattern,
        physical: &self.physical,
    }
}

/// Builds a tensor that borrows this tensor's `physical` mutably.
///
/// `data_type` is copied and `shape` and `pattern` are cloned, so the
/// returned tensor owns its own metadata and holds only an exclusive
/// reference to the `physical` of `self`.
#[inline]
pub fn as_mut(&mut self) -> Tensor<&mut Physical> {
    // Clone the metadata first; the mutable borrow of `physical` is taken last.
    let shape = self.shape.clone();
    let pattern = self.pattern.clone();
    Tensor {
        data_type: self.data_type,
        shape,
        pattern,
        physical: &mut self.physical,
    }
}

/// # Safety
///
/// The caller must ensure that the new `physical` matches the `data_type`, `shape` and `pattern` of `self`.
#[inline]
pub unsafe fn map_physical<U>(&self, f: impl FnOnce(&Physical) -> U) -> Tensor<U> {
pub unsafe fn map_physical<U>(self, f: impl FnOnce(Physical) -> U) -> Tensor<U> {
Tensor {
data_type: self.data_type,
shape: self.shape.clone(),
pattern: self.pattern.clone(),
physical: f(&self.physical),
shape: self.shape,
pattern: self.pattern,
physical: f(self.physical),
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Transformer {
self.0.eos_token_id()
}

pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> (Vec<Id>, Tensor<Storage>) {
pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> (Vec<Id>, Tensor<Vec<u8>>) {
requests.sort_unstable_by_key(|t| t.tokens.len());

// println!("tokens:");
Expand Down Expand Up @@ -211,7 +211,11 @@ impl Transformer {
slice![from begin, take len]
};

let mut logits = tensor(dt, &[tokens.len, voc]);
let mut logits = Tensor::new(
dt,
&[tokens.len, voc],
vec![0u8; (tokens.len * voc) as usize * dt.size()],
);
let mut x = x0.slice(&[tokens, slice![all]]);
// println!("decode slice:\n{}", x.access());

Expand All @@ -225,13 +229,14 @@ impl Transformer {
// println!("model norm:\n{}", x.access());

let lm_head = self.0.lm_head().transpose(&[1, 0]);
mat_mul(&mut logits.access_mut(), 0., &x.access(), &lm_head, 1.);
mat_mul(&mut logits, 0., &x.access(), &lm_head, 1.);
// println!("logits:\n{}", logits.access());

(requests.into_iter().map(|r| r.id).collect(), logits)
}
}

#[inline]
fn tensor(dt: DataType, shape: &[udim]) -> Tensor<Storage> {
Tensor::new(
dt,
Expand Down
8 changes: 6 additions & 2 deletions transformer-nvidia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,11 @@ impl<'ctx> Transformer<'ctx> {
// compute.synchronize();
// println!("model norm:\n{}", map_tensor(&x));

let mut logits = unsafe { logits_dev.map_physical(|dev| vec![0; dev.access().len()]) };
let mut logits = unsafe {
logits_dev
.as_ref()
.map_physical(|dev| vec![0; dev.access().len()])
};
mat_mul(
&self.cublas,
&logits_dev.access(),
Expand Down Expand Up @@ -373,7 +377,7 @@ fn tensor<'ctx>(dt: DataType, shape: &[udim], stream: &Stream<'ctx>) -> Tensor<S
#[allow(unused)]
fn map_tensor(tensor: &Tensor<Storage>) -> Tensor<Vec<u8>> {
unsafe {
tensor.map_physical(|dev| {
tensor.as_ref().map_physical(|dev| {
let dev = dev.access();
let mut buf = vec![0; dev.len()];
dev.copy_out(&mut buf);
Expand Down
2 changes: 2 additions & 0 deletions transformer-nvidia/src/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ impl<'ctx> ModelParameters<'ctx> {
($param:ident) => {
unsafe {
host.$param()
.as_ref()
.map_physical(|slice| stream.from_host(slice).into())
}
};
Expand Down Expand Up @@ -88,6 +89,7 @@ impl<'ctx> LayerParameter<'ctx> {
($param:ident) => {
unsafe {
host.$param(layer)
.as_ref()
.map_physical(|slice| stream.from_host(slice).into())
}
};
Expand Down

0 comments on commit 08528ab

Please sign in to comment.