feat(model-parameters): save the model with qkv concatenated
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 20, 2024
1 parent dd908d0 commit bae589e
Showing 3 changed files with 75 additions and 146 deletions.
54 changes: 31 additions & 23 deletions model-parameters/src/memory/mod.rs
@@ -3,7 +3,7 @@ mod safe_tensors
 
 use crate::{ConfigJson, DataType, Llama2, Storage};
 use common::utok;
-use tensor::{udim, Shape, Tensor};
+use tensor::{Shape, Tensor};
 
 pub use safe_tensors::SafeTensorError;
 pub(crate) use safe_tensors::SafeTensorHeaderJson;
@@ -101,28 +101,11 @@ impl Llama2 for Memory {
 
     #[inline]
     fn w_qkv(&self, layer: usize) -> Tensor<Storage> {
-        let q = &self.layers[layer].self_attn_q_proj;
-        let k = &self.layers[layer].self_attn_k_proj;
-        let v = &self.layers[layer].self_attn_v_proj;
-        let d = self.hidden_size() as udim;
-        let dkv =
-            (self.hidden_size() * self.num_key_value_heads() / self.num_attention_heads()) as udim;
-        let dt = self.config.torch_dtype.size();
-        debug_assert_eq!(q.shape(), &[d, d]);
-        debug_assert_eq!(k.shape(), &[dkv, d]);
-        debug_assert_eq!(v.shape(), &[dkv, d]);
-        let size = (q.size() + k.size() + v.size()) * dt;
-        let mut data = vec![0u8; size];
-        let (q_, kv_) = data.split_at_mut(q.size() * dt);
-        let (k_, v_) = kv_.split_at_mut(k.size() * dt);
-        q_.copy_from_slice(q.physical().as_slice());
-        k_.copy_from_slice(k.physical().as_slice());
-        v_.copy_from_slice(v.physical().as_slice());
-        Tensor::new(
-            self.config.torch_dtype,
-            Shape::from_vec(vec![d + dkv + dkv, d]),
-            Storage::from_blob(data),
-        )
+        concat0(&[
+            &self.layers[layer].self_attn_q_proj,
+            &self.layers[layer].self_attn_k_proj,
+            &self.layers[layer].self_attn_v_proj,
+        ])
     }
 
     #[inline]
@@ -176,6 +159,31 @@ impl Llama2 for Memory {
     }
 }
 
+fn concat0(tensors: &[&Tensor<Storage>]) -> Tensor<Storage> {
+    assert!(!tensors.is_empty());
+    let data_type = tensors[0].data_type();
+    let mut shape = Shape::from_slice(tensors[0].shape());
+
+    debug_assert!(tensors
+        .iter()
+        .all(|t| t.data_type() == data_type && t.shape()[1..] == shape[1..]));
+
+    for t in &tensors[1..] {
+        shape[0] += t.shape()[0];
+    }
+
+    let size = shape.iter().map(|&d| d as usize).product::<usize>() * data_type.size();
+    let mut data = vec![0u8; size];
+    let mut offset = 0;
+    for t in tensors {
+        let len = t.size() * data_type.size();
+        data[offset..][..len].copy_from_slice(t.physical().as_slice());
+        offset += len;
+    }
+
+    Tensor::new(data_type, shape, Storage::from_blob(data))
+}
+
 #[test]
 fn test_load() {
     use std::time::Instant;
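
For orientation, the new concat0 stacks tensors along dimension 0 by copying each one's contiguous bytes back to back, which is why no per-element work is needed. Below is a minimal, self-contained sketch of the same pattern on plain byte buffers; the crate's Tensor/Storage types are elided and all names here are illustrative, not part of the diff.

/// Concatenate row-major matrices along dim 0 by copying their contiguous
/// byte buffers back to back, mirroring what `concat0` does for tensors.
fn concat_dim0(parts: &[(&[u8], (usize, usize))], elem_size: usize) -> (Vec<u8>, (usize, usize)) {
    assert!(!parts.is_empty());
    let (_, (_, cols)) = parts[0];
    // Every part must agree on all dimensions except dim 0.
    assert!(parts.iter().all(|&(_, (_, c))| c == cols));
    let rows: usize = parts.iter().map(|&(_, (r, _))| r).sum();
    let mut data = Vec::with_capacity(rows * cols * elem_size);
    for &(bytes, (r, c)) in parts {
        assert_eq!(bytes.len(), r * c * elem_size);
        // These rows become the next rows of the result.
        data.extend_from_slice(bytes);
    }
    (data, (rows, cols))
}

fn main() {
    // Two matrices of 1-byte elements: a 2x3 stacked on a 1x3 gives a 3x3,
    // just as q [d, d], k [dkv, d], v [dkv, d] stack into [d + 2*dkv, d].
    let q = [1u8, 2, 3, 4, 5, 6];
    let k = [7u8, 8, 9];
    let (qk, shape) = concat_dim0(&[(&q, (2, 3)), (&k, (1, 3))], 1);
    assert_eq!(shape, (3, 3));
    assert_eq!(qk, [1, 2, 3, 4, 5, 6, 7, 8, 9]);
}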
13 changes: 0 additions & 13 deletions model-parameters/src/memory/safe_tensors.rs
@@ -84,16 +84,3 @@ pub(crate) struct SafeTensorHeaderJson {
     #[serde(rename = "__metadata__")]
     pub meta: Option<HashMap<String, serde_json::Value>>,
 }
-
-#[test]
-fn test_meta() {
-    let header = SafeTensorHeaderJson {
-        tensors: HashMap::new(),
-        meta: Some(
-            [("concat_qkv".to_string(), serde_json::Value::Bool(true))]
-                .into_iter()
-                .collect(),
-        ),
-    };
-    println!("{}", serde_json::to_string_pretty(&header).unwrap());
-}
154 changes: 44 additions & 110 deletions model-parameters/src/save.rs
@@ -1,11 +1,12 @@
-use crate::{memory::SafeTensorHeaderJson, ConfigJson, DataType, Llama2};
+use crate::{memory::SafeTensorHeaderJson, ConfigJson, DataType, Llama2, Storage};
 use safetensors::{tensor::TensorInfo, Dtype};
 use std::{
     collections::HashMap,
     fs,
     io::{self, BufWriter, Write},
     path::Path,
 };
+use tensor::Tensor;
 
 pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     let dir = dir.as_ref();
@@ -26,141 +27,76 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     })?;
     fs::write(dir.join("config.json"), config)?;
 
-    let dtype = match model.data_type() {
-        DataType::F16 => Dtype::F16,
-        DataType::BF16 => Dtype::BF16,
-        DataType::F32 => Dtype::F32,
-        _ => todo!(),
-    };
-    let d = model.hidden_size();
-    let dkv = d * model.num_key_value_heads() / model.num_attention_heads();
-    let di = model.intermediate_size();
-    let dv = model.vocab_size();
-
-    struct Offset(usize);
-    impl Offset {
-        #[inline]
-        fn update(&mut self, len: usize) -> (usize, usize) {
-            let start = self.0;
-            self.0 += len;
-            (start, self.0)
-        }
-    }
-
-    let mut offset = Offset(0);
+    let mut offset = 0usize;
     let mut header = SafeTensorHeaderJson {
         tensors: HashMap::new(),
         meta: None,
     };
 
+    let mut tensor_info = |tensor: Tensor<Storage>| TensorInfo {
+        dtype: match tensor.data_type() {
+            DataType::Bool => Dtype::BOOL,
+            DataType::I8 => Dtype::I8,
+            DataType::I16 => Dtype::I16,
+            DataType::I32 => Dtype::I32,
+            DataType::I64 => Dtype::I64,
+            DataType::U8 => Dtype::U8,
+            DataType::U16 => Dtype::U16,
+            DataType::U32 => Dtype::U32,
+            DataType::U64 => Dtype::U64,
+            DataType::F16 => Dtype::F16,
+            DataType::BF16 => Dtype::BF16,
+            DataType::F32 => Dtype::F32,
+            DataType::F64 => Dtype::F64,
+        },
+        shape: tensor.shape().iter().map(|&d| d as _).collect(),
+        data_offsets: {
+            let start = offset;
+            offset += tensor.physical().as_slice().len();
+            (start, offset)
+        },
+    };
+
     header.tensors.insert(
         "model.embed_tokens.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![dv, d],
-            data_offsets: offset.update(model.embed_tokens().physical().as_slice().len()),
-        },
+        tensor_info(model.embed_tokens()),
     );
     for layer in 0..model.num_hidden_layers() {
         header.tensors.insert(
             format!("model.layers.{layer}.input_layernorm.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d],
-                data_offsets: offset
-                    .update(model.input_layernorm(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.input_layernorm(layer)),
         );
         header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.q_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, d],
-                data_offsets: offset
-                    .update(model.self_attn_q_proj(layer).physical().as_slice().len()),
-            },
-        );
-        header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.k_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![dkv, d],
-                data_offsets: offset
-                    .update(model.self_attn_k_proj(layer).physical().as_slice().len()),
-            },
-        );
-        header.tensors.insert(
-            format!("model.layers.{layer}.self_attn.v_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![dkv, d],
-                data_offsets: offset
-                    .update(model.self_attn_v_proj(layer).physical().as_slice().len()),
-            },
+            format!("model.layers.{layer}.self_attn.qkv_proj.weight"),
+            tensor_info(model.w_qkv(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.self_attn.o_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, d],
-                data_offsets: offset
-                    .update(model.self_attn_o_proj(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.self_attn_o_proj(layer)),
        );
         header.tensors.insert(
             format!("model.layers.{layer}.post_attention_layernorm.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d],
-                data_offsets: offset.update(
-                    model
-                        .post_attention_layernorm(layer)
-                        .physical()
-                        .as_slice()
-                        .len(),
-                ),
-            },
+            tensor_info(model.post_attention_layernorm(layer)),
         );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.gate_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![di, d],
-                data_offsets: offset.update(model.mlp_gate(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_gate(layer)),
        );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.down_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![d, di],
-                data_offsets: offset.update(model.mlp_down(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_down(layer)),
        );
         header.tensors.insert(
             format!("model.layers.{layer}.mlp.up_proj.weight"),
-            TensorInfo {
-                dtype,
-                shape: vec![di, d],
-                data_offsets: offset.update(model.mlp_up(layer).physical().as_slice().len()),
-            },
+            tensor_info(model.mlp_up(layer)),
        );
     }
-    header.tensors.insert(
-        "model.norm.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![d],
-            data_offsets: offset.update(model.model_norm().physical().as_slice().len()),
-        },
-    );
-    header.tensors.insert(
-        "lm_head.weight".into(),
-        TensorInfo {
-            dtype,
-            shape: vec![dv, d],
-            data_offsets: offset.update(model.lm_head().physical().as_slice().len()),
-        },
-    );
+    header
+        .tensors
+        .insert("model.norm.weight".into(), tensor_info(model.model_norm()));
+    header
+        .tensors
+        .insert("lm_head.weight".into(), tensor_info(model.lm_head()));
 
     let mut file = fs::File::create(dir.join("model.safetensors"))?;
     let mut write = BufWriter::new(&mut file);
@@ -179,9 +115,7 @@ pub fn save(model: &dyn Llama2, dir: impl AsRef<Path>) -> io::Result<()> {
     write.write_all(model.embed_tokens().physical().as_slice())?;
     for layer in 0..model.num_hidden_layers() {
         write.write_all(model.input_layernorm(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_q_proj(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_k_proj(layer).physical().as_slice())?;
-        write.write_all(model.self_attn_v_proj(layer).physical().as_slice())?;
+        write.write_all(model.w_qkv(layer).physical().as_slice())?;
         write.write_all(model.self_attn_o_proj(layer).physical().as_slice())?;
         write.write_all(model.post_attention_layernorm(layer).physical().as_slice())?;
         write.write_all(model.mlp_gate(layer).physical().as_slice())?;
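
The rewritten save path depends on one invariant: tensor_info is called for each tensor in exactly the order the tensor bytes are later written, so a single running offset produces valid data_offsets for the safetensors header. A minimal sketch of that running-offset pattern, under illustrative names and byte lengths that are not taken from the diff:

use std::collections::HashMap;

/// Illustrative stand-in for the (start, end) byte range that a
/// safetensors header records for each tensor.
#[derive(Debug, PartialEq)]
struct Entry {
    data_offsets: (usize, usize),
}

fn main() {
    // Byte lengths of three tensors, listed in the exact order their
    // bytes will later be appended to the data section.
    let tensors = [
        ("model.norm.weight", 8usize),
        ("lm_head.weight", 24),
        ("model.embed_tokens.weight", 24),
    ];

    let mut offset = 0usize;
    let mut header = HashMap::new();
    for (name, len) in tensors {
        // Same pattern as the new `tensor_info` closure: capture a running
        // offset mutably and advance it by each tensor's byte length.
        let start = offset;
        offset += len;
        header.insert(name, Entry { data_offsets: (start, offset) });
    }

    assert_eq!(header["lm_head.weight"].data_offsets, (8, 32));
    // The recorded ranges tile [0, total) with no gaps, matching write order.
    assert_eq!(offset, 56);
}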
