Skip to content

Commit

Permalink
fix: change some naming
Browse files Browse the repository at this point in the history
  • Loading branch information
PanZezhong1725 committed Nov 21, 2024
1 parent e2bd9b1 commit 370adb0
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

### 1. 算子:SwiGLU函数(10分)

请在`src/operators.rs`中实现SiLU算子,其公式为:
请在`src/operators.rs`中实现SwiGLU算子,其公式为:

$$
y = \operatorname{silu}(x) \times y
Expand Down Expand Up @@ -85,8 +85,8 @@ $$
hidden = rms_norm(residual)
gate = hidden @ gate_weight.T
up = hidden @ up_weight.T
itermediate = gate * sigmoid(gate) * up ## silu
output = itermediate @ down_weight.T
act = gate * sigmoid(gate) * up ## SwiGLU
output = act @ down_weight.T
residual = output + residual
```

Expand Down Expand Up @@ -149,9 +149,9 @@ V = cat(V_cache, V)
### 以下是你需要实现的部分
score = Q @ K.T / sqrt(dim)
attn = softmax(score)
x = attn @ V
x = x @ O_weight.T
residual = x + residual
attn_V = attn @ V
out = attn_V @ O_weight.T
residual = out + residual
```

Self-Attention的调试是很困难的。这里推荐大家使用pytorch来辅助调试。各位可以用transformers库(使用llama模型代码)来加载模型并运行,逐层检查中间张量结果。
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ fn main() {
let output_ids = llama.generate(
input_ids,
500,
0.9,
4,
0.8,
30,
1.,
);
println!("{}", tokenizer.decode(&output_ids, true).unwrap());
Expand Down
6 changes: 3 additions & 3 deletions src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon:
todo!("实现 rms_norm,计算前做一些必要的检查会帮助你后续调试")
}

// y = sigmoid(x) * x * y
// y = silu(x) * y
// hint: this is an element-wise operation
pub fn silu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
// let len = y.size();
// assert!(len == x.size());

Expand Down Expand Up @@ -176,7 +176,7 @@ pub fn random_sample(x: &Tensor<f32>, top_p: f32, top_k: u32, temperature: f32)
fn test_silu() {
let mut y = Tensor::<f32>::new(vec![2., 3., 4.], &vec![1, 3]);
let x = Tensor::<f32>::new(vec![1., 2., 3.], &vec![1, 3]);
silu(&mut y, &x);
swiglu(&mut y, &x);
assert!(y.close_to(
&Tensor::<f32>::new(vec![1.4621172, 5.2847824, 11.43089], &vec![1, 3]),
1e-3
Expand Down

0 comments on commit 370adb0

Please sign in to comment.