refactor(model-parameters): simplify the implementation based on tensor
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 20, 2024
1 parent 3d0c788 commit 619fc05
Showing 7 changed files with 192 additions and 377 deletions.
4 changes: 0 additions & 4 deletions model-parameters/Cargo.toml
@@ -9,13 +9,9 @@ authors = ["YdrMaster <[email protected]>"]
[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
log = "0.4"
half = "2.3"
rayon = "1.8"
memmap2 = "0.9"
safetensors = "0.4"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }

[dev-dependencies]
env_logger = "0.11"
15 changes: 0 additions & 15 deletions model-parameters/src/lib.rs
@@ -3,9 +3,6 @@ mod memory;
mod save;
mod storage;

#[macro_use]
extern crate log;

use common::utok;
use storage::Storage;
use tensor::{DataType, Tensor};
@@ -79,15 +76,3 @@ struct ConfigJson {
pub rope_theta: f32,
pub torch_dtype: DataType,
}

struct LayerParamsOffset {
input_layernorm: usize,
self_attn_q_proj: usize,
self_attn_k_proj: usize,
self_attn_v_proj: usize,
self_attn_o_proj: usize,
post_attention_layernorm: usize,
mlp_gate: usize,
mlp_down: usize,
mlp_up: usize,
}
94 changes: 94 additions & 0 deletions model-parameters/src/memory/cast.rs
@@ -0,0 +1,94 @@
use super::Layer;
use crate::{storage::Storage, ConfigJson, DataType, Llama2, Memory};
use half::{bf16, f16};
use std::sync::Arc;
use tensor::Tensor;

impl Memory {
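/// Converts every parameter tensor of `src` to `new_dtype`, copying the
/// config verbatim except for the updated `torch_dtype`.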
pub fn cast(src: &dyn Llama2, new_dtype: DataType) -> Self {
Self {
config: ConfigJson {
bos_token_id: src.bos_token_id(),
eos_token_id: src.eos_token_id(),
hidden_size: src.hidden_size(),
intermediate_size: src.intermediate_size(),
max_position_embeddings: src.max_position_embeddings(),
num_attention_heads: src.num_attention_heads(),
num_hidden_layers: src.num_hidden_layers(),
num_key_value_heads: src.num_key_value_heads(),
vocab_size: src.vocab_size(),
rms_norm_eps: src.rms_norm_eps(),
rope_theta: src.rope_theta(),
torch_dtype: new_dtype,
},
embed_tokens: cast(src.embed_tokens(), new_dtype),
layers: (0..src.num_hidden_layers())
.map(|l| Layer {
input_layernorm: cast(src.input_layernorm(l), new_dtype),
self_attn_q_proj: cast(src.self_attn_q_proj(l), new_dtype),
self_attn_k_proj: cast(src.self_attn_k_proj(l), new_dtype),
self_attn_v_proj: cast(src.self_attn_v_proj(l), new_dtype),
self_attn_o_proj: cast(src.self_attn_o_proj(l), new_dtype),
post_attention_layernorm: cast(src.post_attention_layernorm(l), new_dtype),
mlp_gate: cast(src.mlp_gate(l), new_dtype),
mlp_down: cast(src.mlp_down(l), new_dtype),
mlp_up: cast(src.mlp_up(l), new_dtype),
})
.collect(),
model_norm: cast(src.model_norm(), new_dtype),
lm_head: cast(src.lm_head(), new_dtype),
}
}
}

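/// Casts one tensor's storage to `new_dtype`; returns the tensor unchanged
/// when the data type already matches.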
fn cast(src: Tensor<Storage>, new_dtype: DataType) -> Tensor<Storage> {
if src.data_type() == new_dtype {
return src;
}

let src_data = src.physical().as_slice();
let mut data = vec![0u8; src.size() * new_dtype.size()];

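// Reinterpret both raw byte buffers as typed slices and convert
// element-wise in parallel with rayon.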
macro_rules! cast {
($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => {
use rayon::iter::*;
use std::{mem::size_of, slice::from_raw_parts, slice::from_raw_parts_mut};

let len = $src.len() / size_of::<$src_ty>();
debug_assert_eq!(len * size_of::<$dst_ty>(), $dst.len());
let src = unsafe { from_raw_parts($src.as_ptr() as *const $src_ty, len) };
let dst = unsafe { from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) };

#[allow(clippy::redundant_closure_call)]
src.par_iter()
.zip(dst)
.for_each(|(src, dst)| *dst = $f(*src));
};
}

match (src.data_type(), new_dtype) {
(DataType::F16, DataType::BF16) => {
cast!(|x: f16| bf16::from_f32(x.to_f32()); src_data, f16 => &mut data, bf16);
}
(DataType::F16, DataType::F32) => {
cast!(|x: f16| x.to_f32(); src_data, f16 => &mut data, f32);
}
(DataType::BF16, DataType::F16) => {
cast!(|x: bf16| f16::from_f32(x.to_f32()); src_data, bf16 => &mut data, f16);
}
(DataType::BF16, DataType::F32) => {
cast!(|x: bf16| x.to_f32(); src_data, bf16 => &mut data, f32);
}
(DataType::F32, DataType::F16) => {
cast!(|x: f32| f16::from_f32(x); src_data, f32 => &mut data, f16);
}
(DataType::F32, DataType::BF16) => {
cast!(|x: f32| bf16::from_f32(x); src_data, f32 => &mut data, bf16);
}
_ => todo!(),
}

let len = data.len();
let physical = Storage::new(Arc::new(data), 0, len);
unsafe { src.cast(new_dtype, physical) }
}
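
As a brief usage sketch of the new entry point (hypothetical: the crate path `model_parameters` and how `model` gets loaded are assumed here, not shown in this diff):

use model_parameters::{DataType, Llama2, Memory};

fn to_f32(model: &dyn Llama2) -> Memory {
    // Re-encode every weight tensor as f32; tensors already in f32
    // are passed through unchanged by the early return in `cast`.
    Memory::cast(model, DataType::F32)
}

Casting tensor by tensor through the `Llama2` accessor trait is what lets the old `LayerParamsOffset` bookkeeping in lib.rs be deleted.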
115 changes: 0 additions & 115 deletions model-parameters/src/memory/inside_memory.rs

This file was deleted.

