Commit
refactor(llama): move cast and save from llama-legacy into llama
Signed-off-by: YdrMaster <[email protected]>
Showing 14 changed files with 496 additions and 383 deletions.
@@ -0,0 +1,59 @@
use crate::{InferenceConfig, LayerStorage, Storage, Weight};
use common::{bf16, f16, Blob};
use tensor::{DataType, Tensor, Ty};

impl Storage {
    /// Casts every weight tensor of the model to `dt`.
    /// Returns `self` unchanged when the model already uses that data type.
    pub fn cast(self, dt: DataType) -> Self {
        if self.config.dt == dt {
            return self;
        }
        Self {
            config: InferenceConfig { dt, ..self.config },
            embed_tokens: cast(self.embed_tokens, dt),
            layers: self
                .layers
                .into_iter()
                .map(|l| LayerStorage {
                    att_layernorm: cast(l.att_layernorm, dt),
                    att_qkv: cast(l.att_qkv, dt),
                    att_o: cast(l.att_o, dt),
                    mlp_layernorm: cast(l.mlp_layernorm, dt),
                    mlp_gate_up: cast(l.mlp_gate_up, dt),
                    mlp_down: cast(l.mlp_down, dt),
                })
                .collect(),
            lm_layernorm: cast(self.lm_layernorm, dt),
            lm_head: cast(self.lm_head, dt),
        }
    }
}

/// Casts one tensor to `dt`, dispatching on the (source, target) pair.
/// Only conversions among F16, BF16, and F32 are implemented so far.
fn cast(src: Tensor<Weight>, dt: DataType) -> Tensor<Weight> {
    match (src.data_type(), dt) {
        (DataType::F16, DataType::BF16) => typed(src, |x: &f16| bf16::from_f32(x.to_f32())),
        (DataType::F16, DataType::F32) => typed(src, |x: &f16| x.to_f32()),
        (DataType::BF16, DataType::F16) => typed(src, |x: &bf16| f16::from_f32(x.to_f32())),
        (DataType::BF16, DataType::F32) => typed(src, |x: &bf16| x.to_f32()),
        (DataType::F32, DataType::F16) => typed(src, |x: &f32| f16::from_f32(*x)),
        (DataType::F32, DataType::BF16) => typed(src, |x: &f32| bf16::from_f32(*x)),
        _ => todo!(),
    }
}

/// Applies `cast` element-wise in parallel, allocating a fresh tensor of `U`.
fn typed<T: Ty + Sync, U: Ty + Send>(
    src: Tensor<Weight>,
    cast: impl Fn(&T) -> U + Sync,
) -> Tensor<Weight> {
    use rayon::iter::*;
    use tensor::{reslice, reslice_mut};

    assert_eq!(src.data_type(), T::DATA_TYPE);
    let mut ans = Tensor::alloc(U::DATA_TYPE, src.shape(), Blob::new);

    // Reinterpret both blobs as typed slices and convert element by element
    // across the rayon thread pool.
    reslice(src.physical())
        .par_iter()
        .zip(reslice_mut(ans.physical_mut()))
        .for_each(|(src, dst)| *dst = cast(src));

    ans.map_physical(|b| b.into())
}
@@ -0,0 +1,77 @@
use common::utok;
use tensor::DataType;

/// The subset of a HuggingFace-style `config.json` that the loader reads.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(crate) struct ConfigJson {
    pub bos_token_id: utok,
    pub eos_token_id: utok,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub vocab_size: usize,
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f32,
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f32,
    pub torch_dtype: DataType,
}

#[inline(always)]
const fn default_rms_norm_eps() -> f32 {
    1e-5
}

#[inline(always)]
const fn default_rope_theta() -> f32 {
    1e4
}
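
For illustration, a hedged sketch of deserializing such a config with serde_json. The values are invented, and the "float16" spelling assumes `DataType` accepts the HuggingFace string form (its actual serde representation lives in the tensor crate); `rms_norm_eps` and `rope_theta` are left out to exercise the `#[serde(default = ...)]` fallbacks.

// Illustrative only: values are made up, and "float16" assumes `DataType`
// deserializes from the HuggingFace-style string.
let json = r#"{
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_size": 4096,
    "intermediate_size": 11008,
    "max_position_embeddings": 4096,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 32,
    "vocab_size": 32000,
    "torch_dtype": "float16"
}"#;
let config: ConfigJson = serde_json::from_str(json).unwrap();
// rms_norm_eps defaults to 1e-5 and rope_theta to 1e4.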

/// Converts between `safe_tensors::Dtype` and `tensor::DataType`;
/// the arm label (`Dtype:` or `DataType:`) names the *input* type.
macro_rules! convert {
    (Dtype: $dtype:expr) => {{
        use common::safe_tensors::Dtype;
        use tensor::DataType;

        match $dtype {
            Dtype::BOOL => DataType::Bool,
            Dtype::I8 => DataType::I8,
            Dtype::I16 => DataType::I16,
            Dtype::I32 => DataType::I32,
            Dtype::I64 => DataType::I64,
            Dtype::U8 => DataType::U8,
            Dtype::U16 => DataType::U16,
            Dtype::U32 => DataType::U32,
            Dtype::U64 => DataType::U64,
            Dtype::F16 => DataType::F16,
            Dtype::BF16 => DataType::BF16,
            Dtype::F32 => DataType::F32,
            Dtype::F64 => DataType::F64,
            _ => unreachable!(),
        }
    }};

    (DataType: $data_type:expr) => {{
        use common::safe_tensors::Dtype;
        use tensor::DataType;

        match $data_type {
            DataType::Bool => Dtype::BOOL,
            DataType::I8 => Dtype::I8,
            DataType::I16 => Dtype::I16,
            DataType::I32 => Dtype::I32,
            DataType::I64 => Dtype::I64,
            DataType::U8 => Dtype::U8,
            DataType::U16 => Dtype::U16,
            DataType::U32 => Dtype::U32,
            DataType::U64 => Dtype::U64,
            DataType::F16 => Dtype::F16,
            DataType::BF16 => Dtype::BF16,
            DataType::F32 => Dtype::F32,
            DataType::F64 => Dtype::F64,
        }
    }};
}

pub(crate) use convert;
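
The two arms are deliberately asymmetric: the `Dtype` match keeps a `_ => unreachable!()` catch-all, presumably because the upstream safetensors enum is non-exhaustive, while the `DataType` match is exhaustive, so adding a variant there forces a compile error here. A short usage sketch, illustrative rather than taken from the diff:

// The arm label names the input type; paths are written out for clarity.
let dt = convert!(Dtype: common::safe_tensors::Dtype::F16); // -> DataType::F16
let raw = convert!(DataType: tensor::DataType::F32);        // -> Dtype::F32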