Skip to content

Commit

Permalink
feat(xtask): 实现模型转换应用程序
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 16, 2024
1 parent 111e0bd commit 4818837
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[alias]
xtask = "run --package xtask --release --"
cast = "xtask cast"
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[workspace]
members = ["common", "model-parameters", "tokenizer", "transformer-cpu"]
members = ["common", "model-parameters", "tokenizer", "transformer-cpu", "xtask"]
resolver = "2"
12 changes: 12 additions & 0 deletions model-parameters/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ mod save;
#[macro_use]
extern crate log;

use std::fmt;

use common::utok;

pub use data_type::DataType;
Expand Down Expand Up @@ -61,6 +63,7 @@ pub trait Llama2 {
}

pub use memory::{Memory, SafeTensorError};
use serde::Serialize;

#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct ConfigJson {
Expand All @@ -73,11 +76,20 @@ struct ConfigJson {
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub vocab_size: usize,
#[serde(serialize_with = "serialize_float")]
pub rms_norm_eps: f32,
#[serde(serialize_with = "serialize_float")]
pub rope_theta: f32,
pub torch_dtype: DataType,
}

fn serialize_float<S>(val: &impl fmt::LowerExp, s: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
format!("{val:e}").serialize(s)
}

struct LayerParamsOffset {
input_layernorm: usize,
self_attn_q_proj: usize,
Expand Down
11 changes: 11 additions & 0 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "xtask"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
model-parameters = { path = "../model-parameters" }
clap = { version = "4.5", features = ["derive"] }
74 changes: 74 additions & 0 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use clap::Parser;
use model_parameters::{save, DataType, Llama2, Memory};
use std::{fs, path::PathBuf, time::Instant};

#[macro_use]
extern crate clap;

fn main() {
use Commands::*;
match Cli::parse().command {
Cast(args) => args.cast(),
}
}

#[derive(Parser)]
#[clap(name = "transformer-utils")]
#[clap(version, about, long_about = None)]
struct Cli {
#[clap(subcommand)]
command: Commands,
}

#[derive(Subcommand)]
enum Commands {
/// Cast model
Cast(CastArgs),
}

#[derive(Args, Default)]
struct CastArgs {
/// Original model directory.
#[clap(short, long)]
model: String,
/// Target model directory.
#[clap(short, long)]
target: Option<String>,
/// Target model type.
#[clap(short, long)]
dt: Option<String>,
}

impl CastArgs {
fn cast(self) {
let ty = match self.dt.as_ref().map(String::as_str) {
Some("f32") | Some("float") | Some("float32") | None => DataType::F32,
Some("f16") | Some("half") | Some("float16") => DataType::F16,
Some("bf16") | Some("bfloat16") => DataType::BF16,
Some(ty) => panic!("Unknown data type: \"{ty}\""),
};
let model_dir = PathBuf::from(self.model);
let time = Instant::now();
let model = Memory::load_safetensors(&model_dir).unwrap();
println!("load model ... {:?}", time.elapsed());
if model.data_type() == ty {
println!("Model already has target data type");
return;
}

let target = self.target.map(PathBuf::from).unwrap_or_else(|| {
model_dir.parent().unwrap().join(format!(
"{}_{ty:?}",
model_dir.file_name().unwrap().to_str().unwrap()
))
});
fs::create_dir_all(&target).unwrap();
let t0 = Instant::now();
let model = Memory::cast(&model, ty);
let t1 = Instant::now();
println!("cast data type ... {:?}", t1 - t0);
save(&model, target).unwrap();
let t2 = Instant::now();
println!("save model ... {:?}", t2 - t1);
}
}

0 comments on commit 4818837

Please sign in to comment.