From 21813af73b51eafb0f8b8b1d28f5d697d37f63ba Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 11 Mar 2024 15:23:51 +0800 Subject: [PATCH] =?UTF-8?q?fix(transformer):=20=E4=B8=BA=20config.json=20?= =?UTF-8?q?=E8=A7=A3=E6=9E=90=E6=8F=90=E4=BE=9B=E5=BF=85=E8=A6=81=E7=9A=84?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- Cargo.lock | 14 ++++++++------ transformer/src/parameters/mod.rs | 12 ++++++++++++ xtask/src/service/channel.rs | 2 +- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e21c76b..6d86151c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,9 +120,9 @@ dependencies = [ [[package]] name = "bytemuck_derive" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", @@ -163,9 +163,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.1" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" +checksum = "b230ab84b0ffdf890d5a10abdbc8b83ae1c4918275daea1ab8801f71536b2651" dependencies = [ "clap_builder", "clap_derive", @@ -173,9 +173,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.1" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ "anstream", "anstyle", @@ -255,6 +255,7 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "cublas" version = "0.1.0" +source = "git+https://github.com/YdrMaster/cuda-bench#7d3f9f91cf35f8dd190aa40362c50cbbbb438846" dependencies = [ "bindgen", "cuda", @@ -264,6 +265,7 @@ dependencies = [ [[package]] name = "cuda" version = "0.1.0" +source = "git+https://github.com/YdrMaster/cuda-bench#7d3f9f91cf35f8dd190aa40362c50cbbbb438846" dependencies = [ "bindgen", "find_cuda_helper", diff --git a/transformer/src/parameters/mod.rs b/transformer/src/parameters/mod.rs index 359f53b1..7999df62 100644 --- a/transformer/src/parameters/mod.rs +++ b/transformer/src/parameters/mod.rs @@ -110,11 +110,23 @@ struct ConfigJson { pub num_hidden_layers: usize, pub num_key_value_heads: usize, pub vocab_size: usize, + #[serde(default = "default_rms_norm_eps")] pub rms_norm_eps: f32, + #[serde(default = "default_rope_theta")] pub rope_theta: f32, pub torch_dtype: DataType, } +#[inline(always)] +const fn default_rms_norm_eps() -> f32 { + 1e-5 +} + +#[inline(always)] +const fn default_rope_theta() -> f32 { + 1e4 +} + impl From<&dyn Llama2> for ConfigJson { fn from(model: &dyn Llama2) -> Self { Self { diff --git a/xtask/src/service/channel.rs b/xtask/src/service/channel.rs index e614a383..0bf2945b 100644 --- a/xtask/src/service/channel.rs +++ b/xtask/src/service/channel.rs @@ -21,7 +21,7 @@ pub(super) enum ReceiveError { #[derive(Debug)] pub(super) enum SendError { - Io(std::io::Error), + // Io(std::io::Error), Json(serde_json::Error), }