From ff9360b1a98c83cdcf382128e9ec18a276f61a48 Mon Sep 17 00:00:00 2001
From: YdrMaster <ydrml@hotmail.com>
Date: Tue, 20 Feb 2024 17:12:38 +0800
Subject: [PATCH] feat(tensor): allow direct access when tensor data is in
 main memory
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster <ydrml@hotmail.com>
---
 model-parameters/src/lib.rs                 | 18 ++++++++--------
 model-parameters/src/memory/cast.rs         |  2 +-
 model-parameters/src/memory/mod.rs          | 23 ++++++---------------
 model-parameters/src/memory/safe_tensors.rs |  2 +-
 model-parameters/src/save.rs                | 22 ++++++++++----------
 tensor/src/tensor.rs                        | 22 ++++++++++++++++++--
 transformer-cpu/src/lib.rs                  |  8 ++-----
 7 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs
index f83b8f2a..5e328a2f 100644
--- a/model-parameters/src/lib.rs
+++ b/model-parameters/src/lib.rs
@@ -75,17 +75,22 @@ struct ConfigJson {
     pub torch_dtype: DataType,
 }
 
-type Blob = dyn 'static + AsRef<[u8]>;
-
 #[derive(Clone)]
 pub struct Storage {
-    data: Arc<Blob>,
+    data: Arc<dyn AsRef<[u8]>>,
     range: Range<usize>,
 }
 
+impl AsRef<[u8]> for Storage {
+    #[inline]
+    fn as_ref(&self) -> &[u8] {
+        &self.data.as_ref().as_ref()[self.range.clone()]
+    }
+}
+
 impl Storage {
     #[inline]
-    pub fn new(data: Arc<Blob>, offset: usize, len: usize) -> Self {
+    pub fn new(data: Arc<dyn AsRef<[u8]>>, offset: usize, len: usize) -> Self {
         Self {
             data,
             range: offset..offset + len,
@@ -100,9 +105,4 @@ impl Storage {
             range: 0..len,
         }
     }
-
-    #[inline]
-    pub fn as_slice(&self) -> &[u8] {
-        &self.data.as_ref().as_ref()[self.range.clone()]
-    }
 }
diff --git a/model-parameters/src/memory/cast.rs b/model-parameters/src/memory/cast.rs
index f38dcda5..0eff62ac 100644
--- a/model-parameters/src/memory/cast.rs
+++ b/model-parameters/src/memory/cast.rs
@@ -43,7 +43,7 @@ fn cast(src: Tensor<Storage>, new_dtype: DataType) -> Tensor<Storage> {
         return src;
     }
 
-    let src_data = src.physical().as_slice();
+    let src_data = src.as_slice();
    let mut data = vec![0u8; src.size() * new_dtype.size()];
 
     macro_rules! cast {
diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs
index 7b4957cb..9b911a7a 100644
--- a/model-parameters/src/memory/mod.rs
+++ b/model-parameters/src/memory/mod.rs
@@ -108,11 +108,7 @@ impl Llama2 for Memory {
         let dt = self.config.torch_dtype.size();
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.end = physical.range.start + d * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[d as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[d, d], physical)
     }
 
     #[inline]
@@ -123,11 +119,7 @@ impl Llama2 for Memory {
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.start += d * d * dt;
         physical.range.end = physical.range.start + dkv * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[dkv as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[dkv, d], physical)
     }
 
     #[inline]
@@ -138,11 +130,7 @@ impl Llama2 for Memory {
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.start += (d + dkv) * d * dt;
         physical.range.end = physical.range.start + dkv * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[dkv as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[dkv, d], physical)
     }
 
     #[inline]
@@ -199,11 +187,12 @@ fn concat0(tensors: &[&Tensor<Storage>]) -> Tensor<Storage> {
     let mut offset = 0;
     for t in tensors {
         let len = t.size() * data_type.size();
-        data[offset..][..len].copy_from_slice(t.physical().as_slice());
+        data[offset..][..len].copy_from_slice(t.as_slice());
         offset += len;
     }
 
-    Tensor::new(data_type, shape, Storage::from_blob(data))
+    let shape = shape.iter().map(|&d| d as usize).collect::<Vec<_>>();
+    Tensor::new(data_type, &shape, Storage::from_blob(data))
 }
 
 #[test]
diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs
index 528098c5..89f69a3a 100644
--- a/model-parameters/src/memory/safe_tensors.rs
+++ b/model-parameters/src/memory/safe_tensors.rs
@@ -49,7 +49,7 @@ impl Memory {
             debug_assert_eq!(data_type, config.torch_dtype);
             Tensor::new(
                 data_type,
-                info.shape.iter().map(|&d| d as _).collect(),
+                &info.shape,
                 Storage::new(mmap.clone(), start, end - start),
             )
         };
diff --git a/model-parameters/src/save.rs b/model-parameters/src/save.rs
index 8080748b..c6918249 100644
--- a/model-parameters/src/save.rs
+++ b/model-parameters/src/save.rs
@@ -52,7 +52,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
                 shape: tensor.shape().iter().map(|&d| d as _).collect(),
                 data_offsets: {
                     let start = offset;
-                    offset += tensor.physical().as_slice().len();
+                    offset += tensor.as_slice().len();
                     (start, offset)
                 },
             };
@@ -112,17 +112,17 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
             write.write_all(&[32])?;
         }
     }
-    write.write_all(model.embed_tokens().physical().as_slice())?;
+    write.write_all(model.embed_tokens().as_slice())?;
     for layer in 0..model.num_hidden_layers() {
-        write.write_all(model.input_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.w_qkv(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_o_proj(layer).physical().as_slice())?;
-        write.write_all(model.post_attention_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.mlp_gate(layer).physical().as_slice())?;
-        write.write_all(model.mlp_down(layer).physical().as_slice())?;
-        write.write_all(model.mlp_up(layer).physical().as_slice())?;
+        write.write_all(model.input_layernorm(layer).as_slice())?;
+        write.write_all(model.w_qkv(layer).as_slice())?;
+        write.write_all(model.self_attn_o_proj(layer).as_slice())?;
+        write.write_all(model.post_attention_layernorm(layer).as_slice())?;
+        write.write_all(model.mlp_gate(layer).as_slice())?;
+        write.write_all(model.mlp_down(layer).as_slice())?;
+        write.write_all(model.mlp_up(layer).as_slice())?;
     }
-    write.write_all(model.model_norm().physical().as_slice())?;
-    write.write_all(model.lm_head().physical().as_slice())?;
+    write.write_all(model.model_norm().as_slice())?;
+    write.write_all(model.lm_head().as_slice())?;
     Ok(())
 }
diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs
index 43f7db37..5d95f83e 100644
--- a/tensor/src/tensor.rs
+++ b/tensor/src/tensor.rs
@@ -11,7 +11,8 @@ pub struct Tensor<Physical> {
 }
 
 impl<Physical: Clone> Tensor<Physical> {
-    pub fn new(data_type: DataType, shape: Shape, physical: Physical) -> Self {
+    pub fn new(data_type: DataType, shape: &[usize], physical: Physical) -> Self {
+        let shape = Shape::from_iter(shape.iter().map(|&d| d as udim));
         Self {
             data_type,
             pattern: Pattern::from_shape(&shape),
@@ -35,6 +36,11 @@ impl<Physical: Clone> Tensor<Physical> {
         &self.physical
     }
 
+    #[inline]
+    pub fn physical_mut(&mut self) -> &mut Physical {
+        &mut self.physical
+    }
+
     #[inline]
     pub fn size(&self) -> usize {
         self.shape.iter().map(|&d| d as usize).product()
@@ -85,6 +91,18 @@ impl<Physical: Clone> Tensor<Physical> {
     }
 }
 
+impl<Physical: AsRef<[u8]>> Tensor<Physical> {
+    pub fn as_slice(&self) -> &[u8] {
+        self.physical.as_ref()
+    }
+}
+
+impl<Physical: AsMut<[u8]>> Tensor<Physical> {
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        self.physical.as_mut()
+    }
+}
+
 pub type Shape = SmallVec<[udim; 4]>;
 pub type Affine = DMatrix<idim>;
 
@@ -108,7 +126,7 @@ fn test() {
     use super::Transpose;
     use smallvec::smallvec;
 
-    let t = Tensor::new(DataType::F32, Shape::from_slice(&[2, 3, 4, 5]), ());
+    let t = Tensor::new(DataType::F32, &[2, 3, 4, 5], ());
     assert_eq!(t.shape(), &[2, 3, 4, 5]);
     assert_eq!(t.pattern.0.as_slice(), &[60, 20, 5, 1, 0]);
     assert_eq!(t.is_contiguous(), true);
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index ff1097be..08f50791 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -38,11 +38,7 @@ impl Transformer {
         let dt = self.model.data_type();
 
         let mut a = vec![0u8; seq_len * d * dt.size()];
-        gather(
-            &mut a,
-            self.model.embed_tokens().physical().as_slice(),
-            tokens,
-        );
+        gather(&mut a, self.model.embed_tokens().as_slice(), tokens);
 
         let mut b = vec![0u8; seq_len * d * dt.size()];
         for l in 0..self.model.num_hidden_layers() {
@@ -51,7 +47,7 @@ impl Transformer {
             let o = &mut b;
             let x = &a;
             let w = self.model.input_layernorm(l);
-            let w = w.physical().as_slice();
+            let w = w.as_slice();
             let theta = self.model.rope_theta();
             rms_norm(o, x, w, theta, dt);
         }
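
Usage sketch (for review only): a minimal example of how the new API reads after this
patch. `Tensor::new(&[usize])`, `as_slice`, `as_mut_slice`, and `Storage::new` are taken
from the diff above; the crate import paths and the `demo` function are assumptions made
purely for illustration and are not part of the patch.

    use std::sync::Arc;

    use model_parameters::Storage; // assumed import path
    use tensor::{DataType, Tensor}; // assumed import path

    fn demo() {
        // A host-memory tensor backed by a plain Vec<u8>: Vec<u8> implements both
        // AsRef<[u8]> and AsMut<[u8]>, so the new accessors apply directly.
        let mut t = Tensor::new(DataType::F32, &[2, 3, 4], vec![0u8; 2 * 3 * 4 * 4]);
        t.as_mut_slice().fill(0); // direct mutable access to the bytes

        // The same works through `Storage`, which now implements AsRef<[u8]>,
        // so the `t.physical().as_slice()` detour is no longer needed when the
        // data lives in main memory.
        let blob: Arc<dyn AsRef<[u8]>> = Arc::new(vec![0u8; 2 * 3 * 4 * 4]);
        let s = Tensor::new(
            DataType::F32,
            &[2, 3, 4],
            Storage::new(blob, 0, 2 * 3 * 4 * 4),
        );
        assert_eq!(s.as_slice().len(), 2 * 3 * 4 * 4);
    }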