diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..5c8ec081 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,3 @@ +[alias] +xtask = "run --package xtask --release --" +cast = "xtask cast" diff --git a/Cargo.toml b/Cargo.toml index 719341da..88abae36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] -members = ["common", "model-parameters", "tokenizer", "transformer-cpu"] +members = ["common", "model-parameters", "tokenizer", "transformer-cpu", "xtask"] resolver = "2" diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs index bb67f49b..9896ad45 100644 --- a/model-parameters/src/lib.rs +++ b/model-parameters/src/lib.rs @@ -5,6 +5,8 @@ mod save; #[macro_use] extern crate log; +use std::fmt; + use common::utok; pub use data_type::DataType; @@ -61,6 +63,7 @@ pub trait Llama2 { } pub use memory::{Memory, SafeTensorError}; +use serde::Serialize; #[derive(serde::Serialize, serde::Deserialize, Debug)] struct ConfigJson { @@ -73,11 +76,20 @@ struct ConfigJson { pub num_hidden_layers: usize, pub num_key_value_heads: usize, pub vocab_size: usize, + #[serde(serialize_with = "serialize_float")] pub rms_norm_eps: f32, + #[serde(serialize_with = "serialize_float")] pub rope_theta: f32, pub torch_dtype: DataType, } +fn serialize_float(val: &impl fmt::LowerExp, s: S) -> Result +where + S: serde::Serializer, +{ + format!("{val:e}").serialize(s) +} + struct LayerParamsOffset { input_layernorm: usize, self_attn_q_proj: usize, diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml new file mode 100644 index 00000000..3add9121 --- /dev/null +++ b/xtask/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "xtask" +version = "0.0.0" +edition = "2021" +authors = ["YdrMaster "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +model-parameters = { path = "../model-parameters" } +clap = { version = "4.5", features = ["derive"] } diff --git a/xtask/src/main.rs b/xtask/src/main.rs new file mode 100644 index 00000000..1a9f51c0 --- /dev/null +++ b/xtask/src/main.rs @@ -0,0 +1,74 @@ +use clap::Parser; +use model_parameters::{save, DataType, Llama2, Memory}; +use std::{fs, path::PathBuf, time::Instant}; + +#[macro_use] +extern crate clap; + +fn main() { + use Commands::*; + match Cli::parse().command { + Cast(args) => args.cast(), + } +} + +#[derive(Parser)] +#[clap(name = "transformer-utils")] +#[clap(version, about, long_about = None)] +struct Cli { + #[clap(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// Cast model + Cast(CastArgs), +} + +#[derive(Args, Default)] +struct CastArgs { + /// Original model directory. + #[clap(short, long)] + model: String, + /// Target model directory. + #[clap(short, long)] + target: Option, + /// Target model type. + #[clap(short, long)] + dt: Option, +} + +impl CastArgs { + fn cast(self) { + let ty = match self.dt.as_ref().map(String::as_str) { + Some("f32") | Some("float") | Some("float32") | None => DataType::F32, + Some("f16") | Some("half") | Some("float16") => DataType::F16, + Some("bf16") | Some("bfloat16") => DataType::BF16, + Some(ty) => panic!("Unknown data type: \"{ty}\""), + }; + let model_dir = PathBuf::from(self.model); + let time = Instant::now(); + let model = Memory::load_safetensors(&model_dir).unwrap(); + println!("load model ... {:?}", time.elapsed()); + if model.data_type() == ty { + println!("Model already has target data type"); + return; + } + + let target = self.target.map(PathBuf::from).unwrap_or_else(|| { + model_dir.parent().unwrap().join(format!( + "{}_{ty:?}", + model_dir.file_name().unwrap().to_str().unwrap() + )) + }); + fs::create_dir_all(&target).unwrap(); + let t0 = Instant::now(); + let model = Memory::cast(&model, ty); + let t1 = Instant::now(); + println!("cast data type ... {:?}", t1 - t0); + save(&model, target).unwrap(); + let t2 = Instant::now(); + println!("save model ... {:?}", t2 - t1); + } +}