From bae589e7429d52a6603c4c0a025b57bf7bb9d7e1 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Tue, 20 Feb 2024 16:09:50 +0800
Subject: [PATCH] feat(model-parameters): save the model with qkv concatenated
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
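
The safetensors file now stores one model.layers.{layer}.self_attn.qkv_proj.weight
tensor per layer instead of separate q/k/v projections: the three weight matrices
are stacked along dim 0, so the blob is all of q's bytes, then k's, then v's, with
shape [d + dkv + dkv, d] (d = hidden_size, dkv = hidden_size * num_key_value_heads
/ num_attention_heads, as in the code this patch removes). For readers that still
want the three matrices, a minimal sketch of splitting the merged blob back apart
(illustration only, not part of this patch; the helper name and parameters are
hypothetical, and it assumes the usual row-major layout):

    /// Splits a row-major `[d + dkv + dkv, d]` qkv blob back into its
    /// q, k and v byte ranges; `row_bytes` is `d * dtype_size`.
    fn split_qkv(qkv: &[u8], d: usize, dkv: usize, row_bytes: usize) -> (&[u8], &[u8], &[u8]) {
        let (q, kv) = qkv.split_at(d * row_bytes); // first d rows are q
        let (k, v) = kv.split_at(dkv * row_bytes); // next dkv rows are k, the rest is v
        (q, k, v)
    }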

Signed-off-by: YdrMaster
---
 model-parameters/src/memory/mod.rs          |  54 ++++---
 model-parameters/src/memory/safe_tensors.rs |  13 --
 model-parameters/src/save.rs                | 154 ++++++--------
 3 files changed, 75 insertions(+), 146 deletions(-)

diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs
index a0bc2e4d..373caf8c 100644
--- a/model-parameters/src/memory/mod.rs
+++ b/model-parameters/src/memory/mod.rs
@@ -3,7 +3,7 @@ mod safe_tensors;
 
 use crate::{ConfigJson, DataType, Llama2, Storage};
 use common::utok;
-use tensor::{udim, Shape, Tensor};
+use tensor::{Shape, Tensor};
 
 pub use safe_tensors::SafeTensorError;
 pub(crate) use safe_tensors::SafeTensorHeaderJson;
@@ -101,28 +101,11 @@ impl Llama2 for Memory {
 
     #[inline]
     fn w_qkv(&self, layer: usize) -> Tensor<Storage> {
-        let q = &self.layers[layer].self_attn_q_proj;
-        let k = &self.layers[layer].self_attn_k_proj;
-        let v = &self.layers[layer].self_attn_v_proj;
-        let d = self.hidden_size() as udim;
-        let dkv =
-            (self.hidden_size() * self.num_key_value_heads() / self.num_attention_heads()) as udim;
-        let dt = self.config.torch_dtype.size();
-        debug_assert_eq!(q.shape(), &[d, d]);
-        debug_assert_eq!(k.shape(), &[dkv, d]);
-        debug_assert_eq!(v.shape(), &[dkv, d]);
-        let size = (q.size() + k.size() + v.size()) * dt;
-        let mut data = vec![0u8; size];
-        let (q_, kv_) = data.split_at_mut(q.size() * dt);
-        let (k_, v_) = kv_.split_at_mut(k.size() * dt);
-        q_.copy_from_slice(q.physical().as_slice());
-        k_.copy_from_slice(k.physical().as_slice());
-        v_.copy_from_slice(v.physical().as_slice());
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_vec(vec![d + dkv + dkv, d]),
-            Storage::from_blob(data),
-        )
+        concat0(&[
+            &self.layers[layer].self_attn_q_proj,
+            &self.layers[layer].self_attn_k_proj,
+            &self.layers[layer].self_attn_v_proj,
+        ])
     }
 
     #[inline]
@@ -176,6 +159,31 @@ impl Llama2 for Memory {
     }
 }
 
+fn concat0(tensors: &[&Tensor<Storage>]) -> Tensor<Storage> {
+    assert!(!tensors.is_empty());
+    let data_type = tensors[0].data_type();
+    let mut shape = Shape::from_slice(tensors[0].shape());
+
+    debug_assert!(tensors
+        .iter()
+        .all(|t| t.data_type() == data_type && t.shape()[1..] == shape[1..]));
+
+    for t in &tensors[1..] {
+        shape[0] += t.shape()[0];
+    }
+
+    let size = shape.iter().map(|&d| d as usize).product::<usize>() * data_type.size();
+    let mut data = vec![0u8; size];
+    let mut offset = 0;
+    for t in tensors {
+        let len = t.size() * data_type.size();
+        data[offset..][..len].copy_from_slice(t.physical().as_slice());
+        offset += len;
+    }
+
+    Tensor::new(data_type, shape, Storage::from_blob(data))
+}
+
 #[test]
 fn test_load() {
     use std::time::Instant;
diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs
index 40bf91e9..5c491a4a 100644
--- a/model-parameters/src/memory/safe_tensors.rs
+++ b/model-parameters/src/memory/safe_tensors.rs
@@ -84,16 +84,3 @@ pub(crate) struct SafeTensorHeaderJson {
     #[serde(rename = "__metadata__")]
     pub meta: Option<HashMap<String, serde_json::Value>>,
 }
-
-#[test]
-fn test_meta() {
-    let header = SafeTensorHeaderJson {
-        tensors: HashMap::new(),
-        meta: Some(
-            [("concat_qkv".to_string(), serde_json::Value::Bool(true))]
-                .into_iter()
-                .collect(),
-        ),
-    };
-    println!("{}", serde_json::to_string_pretty(&header).unwrap());
-}
diff --git a/model-parameters/src/save.rs b/model-parameters/src/save.rs
index 2aee65fd..8080748b 100644
--- a/model-parameters/src/save.rs
+++ b/model-parameters/src/save.rs
@@ -1,4 +1,4 @@
-use crate::{memory::SafeTensorHeaderJson, ConfigJson, DataType, Llama2};
+use crate::{memory::SafeTensorHeaderJson, ConfigJson, DataType, Llama2, Storage};
 use safetensors::{tensor::TensorInfo, Dtype};
 use std::{
     collections::HashMap,
@@ -6,6 +6,7 @@ use std::{
     io::{self, BufWriter, Write},
     path::Path,
 };
+use tensor::Tensor;
 
 pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     let dir = dir.as_ref();
@@ -26,141 +27,76 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     })?;
     fs::write(dir.join("config.json"), config)?;
 
-    let dtype = match model.data_type() {
-        DataType::F16 => Dtype::F16,
-        DataType::BF16 => Dtype::BF16,
-        DataType::F32 => Dtype::F32,
-        _ => todo!(),
-    };
-    let d = model.hidden_size();
-    let dkv = d * model.num_key_value_heads() / model.num_attention_heads();
-    let di = model.intermediate_size();
-    let dv = model.vocab_size();
-
-    struct Offset(usize);
-    impl Offset {
-        #[inline]
-        fn update(&mut self, len: usize) -> (usize, usize) {
-            let start = self.0;
-            self.0 += len;
-            (start, self.0)
-        }
-    }
-
-    let mut offset = Offset(0);
+    let mut offset = 0usize;
     let mut header = SafeTensorHeaderJson {
         tensors: HashMap::new(),
         meta: None,
     };
+
+    let mut tensor_info = |tensor: Tensor<Storage>| TensorInfo {
+        dtype: match tensor.data_type() {
+            DataType::Bool => Dtype::BOOL,
+            DataType::I8 => Dtype::I8,
+            DataType::I16 => Dtype::I16,
+            DataType::I32 => Dtype::I32,
+            DataType::I64 => Dtype::I64,
+            DataType::U8 => Dtype::U8,
+            DataType::U16 => Dtype::U16,
+            DataType::U32 => Dtype::U32,
+            DataType::U64 => Dtype::U64,
+            DataType::F16 => Dtype::F16,
+            DataType::BF16 => Dtype::BF16,
+            DataType::F32 => Dtype::F32,
+            DataType::F64 => Dtype::F64,
+        },
+        shape: tensor.shape().iter().map(|&d| d as _).collect(),
+        data_offsets: {
+            let start = offset;
+            offset += tensor.physical().as_slice().len();
+            (start, offset)
+        },
+    };
+
     header.tensors.insert(
         "model.embed_tokens.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![dv, d],
-            data_offsets: offset.update(model.embed_tokens().physical().as_slice().len()),
-        },
+        tensor_info(model.embed_tokens()),
     );
     for layer in 0..model.num_hidden_layers() {
         header.tensors.insert(
             format!("model.layers.{layer}.input_layernorm.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d],
-                data_offsets: offset
-                    .update(model.input_layernorm(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.input_layernorm(layer)),
         );
         header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.q_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, d],
-                data_offsets: offset
-                    .update(model.self_attn_q_proj(layer).physical().as_slice().len()),
-            },
-        );
-        header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.k_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![dkv, d],
-                data_offsets: offset
-                    .update(model.self_attn_k_proj(layer).physical().as_slice().len()),
-            },
-        );
-        header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.v_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![dkv, d],
-                data_offsets: offset
-                    .update(model.self_attn_v_proj(layer).physical().as_slice().len()),
-            },
+            format!("model.layers.{layer}.self_attn.qkv_proj.weight"),
+            tensor_info(model.w_qkv(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.self_attn.o_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, d],
-                data_offsets: offset
-                    .update(model.self_attn_o_proj(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.self_attn_o_proj(layer)),
        );
         header.tensors.insert(
             format!("model.layers.{layer}.post_attention_layernorm.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d],
-                data_offsets: offset.update(
-                    model
-                        .post_attention_layernorm(layer)
-                        .physical()
-                        .as_slice()
-                        .len(),
-                ),
-            },
+            tensor_info(model.post_attention_layernorm(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.gate_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![di, d],
-                data_offsets: offset.update(model.mlp_gate(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_gate(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.down_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, di],
-                data_offsets: offset.update(model.mlp_down(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_down(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.up_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![di, d],
-                data_offsets: offset.update(model.mlp_up(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_up(layer)),
         );
     }
-    header.tensors.insert(
-        "model.norm.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![d],
-            data_offsets: offset.update(model.model_norm().physical().as_slice().len()),
-        },
-    );
-    header.tensors.insert(
-        "lm_head.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![dv, d],
-            data_offsets: offset.update(model.lm_head().physical().as_slice().len()),
-        },
-    );
+    header
+        .tensors
+        .insert("model.norm.weight".into(), tensor_info(model.model_norm()));
+    header
+        .tensors
+        .insert("lm_head.weight".into(), tensor_info(model.lm_head()));
 
     let mut file = fs::File::create(dir.join("model.safetensors"))?;
     let mut write = BufWriter::new(&mut file);
@@ -179,9 +115,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     write.write_all(model.embed_tokens().physical().as_slice())?;
     for layer in 0..model.num_hidden_layers() {
         write.write_all(model.input_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_q_proj(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_k_proj(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_v_proj(layer).physical().as_slice())?;
+        write.write_all(model.w_qkv(layer).physical().as_slice())?;
         write.write_all(model.self_attn_o_proj(layer).physical().as_slice())?;
write.write_all(model.post_attention_layernorm(layer).physical().as_slice())?; write.write_all(model.mlp_gate(layer).physical().as_slice())?;