From ff9360b1a98c83cdcf382128e9ec18a276f61a48 Mon Sep 17 00:00:00 2001
From: YdrMaster <ydrml@hotmail.com>
Date: Tue, 20 Feb 2024 17:12:38 +0800
Subject: [PATCH] =?UTF-8?q?feat(tensor):=20=E5=A6=82=E6=9E=9C=E5=BC=A0?=
 =?UTF-8?q?=E9=87=8F=E6=95=B0=E6=8D=AE=E5=9C=A8=E4=B8=BB=E5=AD=98=E4=B8=AD?=
 =?UTF-8?q?=EF=BC=8C=E5=85=81=E8=AE=B8=E7=9B=B4=E6=8E=A5=E8=AE=BF=E9=97=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster <ydrml@hotmail.com>
---
 model-parameters/src/lib.rs                 | 18 ++++++++--------
 model-parameters/src/memory/cast.rs         |  2 +-
 model-parameters/src/memory/mod.rs          | 23 ++++++---------------
 model-parameters/src/memory/safe_tensors.rs |  2 +-
 model-parameters/src/save.rs                | 22 ++++++++++----------
 tensor/src/tensor.rs                        | 22 ++++++++++++++++++--
 transformer-cpu/src/lib.rs                  |  8 ++-----
 7 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs
index f83b8f2a..5e328a2f 100644
--- a/model-parameters/src/lib.rs
+++ b/model-parameters/src/lib.rs
@@ -75,17 +75,22 @@ struct ConfigJson {
     pub torch_dtype: DataType,
 }
 
-type Blob = dyn 'static + AsRef<[u8]>;
-
 #[derive(Clone)]
 pub struct Storage {
-    data: Arc<Blob>,
+    data: Arc<dyn AsRef<[u8]>>,
     range: Range<usize>,
 }
 
+impl AsRef<[u8]> for Storage {
+    #[inline]
+    fn as_ref(&self) -> &[u8] {
+        &self.data.as_ref().as_ref()[self.range.clone()]
+    }
+}
+
 impl Storage {
     #[inline]
-    pub fn new(data: Arc<Blob>, offset: usize, len: usize) -> Self {
+    pub fn new(data: Arc<dyn AsRef<[u8]>>, offset: usize, len: usize) -> Self {
         Self {
             data,
             range: offset..offset + len,
@@ -100,9 +105,4 @@ impl Storage {
             range: 0..len,
         }
     }
-
-    #[inline]
-    pub fn as_slice(&self) -> &[u8] {
-        &self.data.as_ref().as_ref()[self.range.clone()]
-    }
 }
diff --git a/model-parameters/src/memory/cast.rs b/model-parameters/src/memory/cast.rs
index f38dcda5..0eff62ac 100644
--- a/model-parameters/src/memory/cast.rs
+++ b/model-parameters/src/memory/cast.rs
@@ -43,7 +43,7 @@ fn cast(src: Tensor<Storage>, new_dtype: DataType) -> Tensor<Storage> {
         return src;
     }
 
-    let src_data = src.physical().as_slice();
+    let src_data = src.as_slice();
     let mut data = vec![0u8; src.size() * new_dtype.size()];
 
     macro_rules! cast {
diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs
index 7b4957cb..9b911a7a 100644
--- a/model-parameters/src/memory/mod.rs
+++ b/model-parameters/src/memory/mod.rs
@@ -108,11 +108,7 @@ impl Llama2 for Memory {
         let dt = self.config.torch_dtype.size();
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.end = physical.range.start + d * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[d as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[d, d], physical)
     }
 
     #[inline]
@@ -123,11 +119,7 @@ impl Llama2 for Memory {
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.start += d * d * dt;
         physical.range.end = physical.range.start + dkv * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[dkv as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[dkv, d], physical)
     }
 
     #[inline]
@@ -138,11 +130,7 @@ impl Llama2 for Memory {
         let mut physical = self.layers[layer].w_qkv.physical().clone();
         physical.range.start += (d + dkv) * d * dt;
         physical.range.end = physical.range.start + dkv * d * dt;
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_slice(&[dkv as _, d as _]),
-            physical,
-        )
+        Tensor::new(self.config.torch_dtype, &[dkv, d], physical)
     }
 
     #[inline]
@@ -199,11 +187,12 @@ fn concat0(tensors: &[&Tensor<Storage>]) -> Tensor<Storage> {
     let mut offset = 0;
     for t in tensors {
         let len = t.size() * data_type.size();
-        data[offset..][..len].copy_from_slice(t.physical().as_slice());
+        data[offset..][..len].copy_from_slice(t.as_slice());
         offset += len;
     }
 
-    Tensor::new(data_type, shape, Storage::from_blob(data))
+    let shape = shape.iter().map(|&d| d as usize).collect::<Vec<_>>();
+    Tensor::new(data_type, &shape, Storage::from_blob(data))
 }
 
 #[test]
diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs
index 528098c5..89f69a3a 100644
--- a/model-parameters/src/memory/safe_tensors.rs
+++ b/model-parameters/src/memory/safe_tensors.rs
@@ -49,7 +49,7 @@ impl Memory {
             debug_assert_eq!(data_type, config.torch_dtype);
             Tensor::new(
                 data_type,
-                info.shape.iter().map(|&d| d as _).collect(),
+                &info.shape,
                 Storage::new(mmap.clone(), start, end - start),
             )
         };
diff --git a/model-parameters/src/save.rs b/model-parameters/src/save.rs
index 8080748b..c6918249 100644
--- a/model-parameters/src/save.rs
+++ b/model-parameters/src/save.rs
@@ -52,7 +52,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
         shape: tensor.shape().iter().map(|&d| d as _).collect(),
         data_offsets: {
             let start = offset;
-            offset += tensor.physical().as_slice().len();
+            offset += tensor.as_slice().len();
             (start, offset)
         },
     };
@@ -112,17 +112,17 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
             write.write_all(&[32])?;
         }
     }
-    write.write_all(model.embed_tokens().physical().as_slice())?;
+    write.write_all(model.embed_tokens().as_slice())?;
     for layer in 0..model.num_hidden_layers() {
-        write.write_all(model.input_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.w_qkv(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_o_proj(layer).physical().as_slice())?;
-        write.write_all(model.post_attention_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.mlp_gate(layer).physical().as_slice())?;
-        write.write_all(model.mlp_down(layer).physical().as_slice())?;
-        write.write_all(model.mlp_up(layer).physical().as_slice())?;
+        write.write_all(model.input_layernorm(layer).as_slice())?;
+        write.write_all(model.w_qkv(layer).as_slice())?;
+        write.write_all(model.self_attn_o_proj(layer).as_slice())?;
+        write.write_all(model.post_attention_layernorm(layer).as_slice())?;
+        write.write_all(model.mlp_gate(layer).as_slice())?;
+        write.write_all(model.mlp_down(layer).as_slice())?;
+        write.write_all(model.mlp_up(layer).as_slice())?;
     }
-    write.write_all(model.model_norm().physical().as_slice())?;
-    write.write_all(model.lm_head().physical().as_slice())?;
+    write.write_all(model.model_norm().as_slice())?;
+    write.write_all(model.lm_head().as_slice())?;
     Ok(())
 }
diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs
index 43f7db37..5d95f83e 100644
--- a/tensor/src/tensor.rs
+++ b/tensor/src/tensor.rs
@@ -11,7 +11,8 @@ pub struct Tensor<Physical> {
 }
 
 impl<Physical: Clone> Tensor<Physical> {
-    pub fn new(data_type: DataType, shape: Shape, physical: Physical) -> Self {
+    pub fn new(data_type: DataType, shape: &[usize], physical: Physical) -> Self {
+        let shape = Shape::from_iter(shape.iter().map(|&d| d as udim));
         Self {
             data_type,
             pattern: Pattern::from_shape(&shape),
@@ -35,6 +36,11 @@ impl<Physical: Clone> Tensor<Physical> {
         &self.physical
     }
 
+    #[inline]
+    pub fn physical_mut(&mut self) -> &mut Physical {
+        &mut self.physical
+    }
+
     #[inline]
     pub fn size(&self) -> usize {
         self.shape.iter().map(|&d| d as usize).product()
@@ -85,6 +91,18 @@ impl<Physical: Clone> Tensor<Physical> {
     }
 }
 
+impl<Physical: AsRef<[u8]>> Tensor<Physical> {
+    pub fn as_slice(&self) -> &[u8] {
+        self.physical.as_ref()
+    }
+}
+
+impl<Physical: AsMut<[u8]>> Tensor<Physical> {
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        self.physical.as_mut()
+    }
+}
+
 pub type Shape = SmallVec<[udim; 4]>;
 pub type Affine = DMatrix<idim>;
 
@@ -108,7 +126,7 @@ fn test() {
     use super::Transpose;
     use smallvec::smallvec;
 
-    let t = Tensor::new(DataType::F32, Shape::from_slice(&[2, 3, 4, 5]), ());
+    let t = Tensor::new(DataType::F32, &[2, 3, 4, 5], ());
     assert_eq!(t.shape(), &[2, 3, 4, 5]);
     assert_eq!(t.pattern.0.as_slice(), &[60, 20, 5, 1, 0]);
     assert_eq!(t.is_contiguous(), true);
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index ff1097be..08f50791 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -38,11 +38,7 @@ impl Transformer {
         let dt = self.model.data_type();
 
         let mut a = vec![0u8; seq_len * d * dt.size()];
-        gather(
-            &mut a,
-            self.model.embed_tokens().physical().as_slice(),
-            tokens,
-        );
+        gather(&mut a, self.model.embed_tokens().as_slice(), tokens);
 
         let mut b = vec![0u8; seq_len * d * dt.size()];
         for l in 0..self.model.num_hidden_layers() {
@@ -51,7 +47,7 @@ impl Transformer {
                 let o = &mut b;
                 let x = &a;
                 let w = self.model.input_layernorm(l);
-                let w = w.physical().as_slice();
+                let w = w.as_slice();
                 let theta = self.model.rope_theta();
                 rms_norm(o, x, w, theta, dt);
             }