Skip to content

Commit

Permalink
fix(transformer): 为 config.json 解析提供必要的默认值
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 11, 2024
1 parent ac1cbdc commit 21813af
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
14 changes: 8 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions transformer/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,23 @@ struct ConfigJson {
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub vocab_size: usize,
#[serde(default = "default_rms_norm_eps")]
pub rms_norm_eps: f32,
#[serde(default = "default_rope_theta")]
pub rope_theta: f32,
pub torch_dtype: DataType,
}

#[inline(always)]
const fn default_rms_norm_eps() -> f32 {
1e-5
}

#[inline(always)]
const fn default_rope_theta() -> f32 {
1e4
}

impl From<&dyn Llama2> for ConfigJson {
fn from(model: &dyn Llama2) -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion xtask/src/service/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub(super) enum ReceiveError {

#[derive(Debug)]
pub(super) enum SendError {
Io(std::io::Error),
// Io(std::io::Error),
Json(serde_json::Error),
}

Expand Down

0 comments on commit 21813af

Please sign in to comment.