From 103f42c1050bdffff665ef78f0295792b49244f4 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 17 Jan 2025 13:34:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=A4=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- models/llama/nvidia-gpu/src/lib.rs | 1 + test-utils/src/lib.rs | 41 +++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/models/llama/nvidia-gpu/src/lib.rs b/models/llama/nvidia-gpu/src/lib.rs index 11222dd..8d0439b 100644 --- a/models/llama/nvidia-gpu/src/lib.rs +++ b/models/llama/nvidia-gpu/src/lib.rs @@ -244,6 +244,7 @@ impl<'blk> Weights<'blk> { load! { attn_norm attn_qkv + attn_qkv_bias attn_o ffn_norm ffn_gate_inp diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 07cee67..170c302 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -145,27 +145,50 @@ pub fn test_infer( let maybe_file = Path::new(prompt); if maybe_file.is_file() { let file = std::fs::read_to_string(maybe_file).unwrap(); + let mut correct = 0; + let mut total = 0; for line in file.lines() { - let line = serde_json::from_str::(line).unwrap(); - let prompt = format!("{line}\n"); - - // print_now!("{prompt}"); + let line = + serde_json::from_str::>(line).unwrap(); + let serde_json::Value::String(prompt) = &line["origin_prompt"] else { + unreachable!() + }; + let serde_json::Value::String(gold) = &line["gold"] else { + unreachable!() + }; let mut tokens = tokenizer.encode(&prompt); + let mut ans = String::new(); let mut pos = 0; + let mut result = false; for _ in 0..max_steps { let next = lm(&tokens, pos); - - pos += tokens.len(); if next == eos { break; } - let piece = tokenizer.decode(next); - print_now!("{piece}"); + pos += tokens.len(); tokens = vec![next]; + + let piece = tokenizer.decode(next); + ans.push_str(&piece); + + if let Some(x) = piece.chars().find(|c| matches!(c, 'A'..='D')) { + result = x == gold.chars().next().unwrap(); + break; + } } - println!() + + if result { + correct += 1 + } + total += 1; + + println!( + "({correct:>4}/{total:<4} {:6.2}%) {} {gold} {ans}", + 100. * (correct as f64 / total as f64), + if result { "✔" } else { "✘" }, + ) } return; }