refactor(llama): move cast and save from llama-legacy to llama
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 29, 2024
1 parent 39e6d24 commit 6acad90
Showing 14 changed files with 496 additions and 383 deletions.
7 changes: 4 additions & 3 deletions Cargo.lock


63 changes: 0 additions & 63 deletions models/llama-legacy/src/cast.rs

This file was deleted.

3 changes: 0 additions & 3 deletions models/llama-legacy/src/lib.rs
@@ -1,7 +1,5 @@
mod cast;
mod memory;
mod safe_tensors;
mod save;
mod storage;

use common::utok;
@@ -10,7 +8,6 @@ mod distribute;

pub use distribute::{DistributeScheme, DistributedLayer, Distributer};
pub use memory::Memory;
pub use save::save;
pub use storage::Storage;

pub trait Llama2 {
20 changes: 0 additions & 20 deletions models/llama-legacy/src/memory.rs
@@ -172,23 +172,3 @@ impl Llama2 for Memory
        self.lm_head.clone()
    }
}

#[test]
fn test_load() {
    use std::time::Instant;

    let Some(model_dir) = common::test_model::find() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    let t0 = Instant::now();
    let model = Memory::load_safetensors(model_dir).unwrap();
    let t1 = Instant::now();
    println!("mmap {:?}", t1 - t0);

    let t0 = Instant::now();
    let _inside_memory = Memory::cast(&model, DataType::F32);
    let t1 = Instant::now();
    println!("cast {:?}", t1 - t0);
}
117 changes: 0 additions & 117 deletions models/llama-legacy/src/save.rs

This file was deleted.

1 change: 1 addition & 0 deletions models/llama/Cargo.toml
@@ -11,3 +11,4 @@ common = { path = "../../common" }
tensor = { path = "../../tensor" }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
rayon.workspace = true
59 changes: 59 additions & 0 deletions models/llama/src/cast.rs
@@ -0,0 +1,59 @@
use crate::{InferenceConfig, LayerStorage, Storage, Weight};
use common::{bf16, f16, Blob};
use tensor::{DataType, Tensor, Ty};

impl Storage {
    pub fn cast(self, dt: DataType) -> Self {
        if self.config.dt == dt {
            return self;
        }
        Self {
            config: InferenceConfig { dt, ..self.config },
            embed_tokens: cast(self.embed_tokens, dt),
            layers: self
                .layers
                .into_iter()
                .map(|l| LayerStorage {
                    att_layernorm: cast(l.att_layernorm, dt),
                    att_qkv: cast(l.att_qkv, dt),
                    att_o: cast(l.att_o, dt),
                    mlp_layernorm: cast(l.mlp_layernorm, dt),
                    mlp_gate_up: cast(l.mlp_gate_up, dt),
                    mlp_down: cast(l.mlp_down, dt),
                })
                .collect(),
            lm_layernorm: cast(self.lm_layernorm, dt),
            lm_head: cast(self.lm_head, dt),
        }
    }
}

fn cast(src: Tensor<Weight>, dt: DataType) -> Tensor<Weight> {
    match (src.data_type(), dt) {
        (DataType::F16, DataType::BF16) => typed(src, |x: &f16| bf16::from_f32(x.to_f32())),
        (DataType::F16, DataType::F32) => typed(src, |x: &f16| x.to_f32()),
        (DataType::BF16, DataType::F16) => typed(src, |x: &bf16| f16::from_f32(x.to_f32())),
        (DataType::BF16, DataType::F32) => typed(src, |x: &bf16| x.to_f32()),
        (DataType::F32, DataType::F16) => typed(src, |x: &f32| f16::from_f32(*x)),
        (DataType::F32, DataType::BF16) => typed(src, |x: &f32| bf16::from_f32(*x)),
        _ => todo!(),
    }
}

fn typed<T: Ty + Sync, U: Ty + Send>(
    src: Tensor<Weight>,
    cast: impl Fn(&T) -> U + Sync,
) -> Tensor<Weight> {
    use rayon::iter::*;
    use tensor::{reslice, reslice_mut};

    assert_eq!(src.data_type(), T::DATA_TYPE);
    let mut ans = Tensor::alloc(U::DATA_TYPE, src.shape(), Blob::new);

    reslice(src.physical())
        .par_iter()
        .zip(reslice_mut(ans.physical_mut()))
        .for_each(|(src, dst)| *dst = cast(src));

    ans.map_physical(|b| b.into())
}
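
A minimal usage sketch of the relocated cast, assumed to live inside the llama crate. The loader name Storage::load_safetensors is hypothetical, borrowed from the legacy Memory API for illustration; the diff above only adds cast.

use crate::Storage;
use tensor::DataType;

fn cast_example(model_dir: &std::path::Path) {
    // Hypothetical loader: the new crate is assumed here to expose a
    // constructor comparable to the legacy Memory::load_safetensors.
    let model = Storage::load_safetensors(model_dir).unwrap();

    // cast consumes the Storage and returns a new one with every weight
    // tensor converted element-wise (in parallel, via rayon). If the target
    // data type already equals config.dt, the original value is returned.
    let _model_f32 = model.cast(DataType::F32);
}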
77 changes: 77 additions & 0 deletions models/llama/src/json.rs
@@ -0,0 +1,77 @@
use common::utok;
use tensor::DataType;

#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(crate) struct ConfigJson {
    pub bos_token_id: utok,
    pub eos_token_id: utok,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub vocab_size: usize,
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f32,
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f32,
    pub torch_dtype: DataType,
}

#[inline(always)]
const fn default_rms_norm_eps() -> f32 {
    1e-5
}

#[inline(always)]
const fn default_rope_theta() -> f32 {
    1e4
}

macro_rules! convert {
    (Dtype: $dtype:expr) => {{
        use common::safe_tensors::Dtype;
        use tensor::DataType;

        match $dtype {
            Dtype::BOOL => DataType::Bool,
            Dtype::I8 => DataType::I8,
            Dtype::I16 => DataType::I16,
            Dtype::I32 => DataType::I32,
            Dtype::I64 => DataType::I64,
            Dtype::U8 => DataType::U8,
            Dtype::U16 => DataType::U16,
            Dtype::U32 => DataType::U32,
            Dtype::U64 => DataType::U64,
            Dtype::F16 => DataType::F16,
            Dtype::BF16 => DataType::BF16,
            Dtype::F32 => DataType::F32,
            Dtype::F64 => DataType::F64,
            _ => unreachable!(),
        }
    }};

    (DataType: $data_type:expr) => {{
        use common::safe_tensors::Dtype;
        use tensor::DataType;

        match $data_type {
            DataType::Bool => Dtype::BOOL,
            DataType::I8 => Dtype::I8,
            DataType::I16 => Dtype::I16,
            DataType::I32 => Dtype::I32,
            DataType::I64 => Dtype::I64,
            DataType::U8 => Dtype::U8,
            DataType::U16 => Dtype::U16,
            DataType::U32 => Dtype::U32,
            DataType::U64 => Dtype::U64,
            DataType::F16 => Dtype::F16,
            DataType::BF16 => Dtype::BF16,
            DataType::F32 => Dtype::F32,
            DataType::F64 => Dtype::F64,
        }
    }};
}

pub(crate) use convert;
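
A minimal sketch of how convert! might be called from elsewhere in the crate, mapping a safetensors Dtype read from a checkpoint into the crate's DataType and back; the wrapper functions are illustrative, not part of this diff.

use common::safe_tensors::Dtype;
use tensor::DataType;

// Checkpoint dtype -> in-memory data type.
fn dtype_to_data_type(dt: Dtype) -> DataType {
    convert!(Dtype: dt)
}

// In-memory data type -> checkpoint dtype.
fn data_type_to_dtype(dt: DataType) -> Dtype {
    convert!(DataType: dt)
}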
