diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 07cee67..1350695 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -145,27 +145,35 @@ pub fn test_infer( let maybe_file = Path::new(prompt); if maybe_file.is_file() { let file = std::fs::read_to_string(maybe_file).unwrap(); + let mut correct = 0; + let mut total = 0; for line in file.lines() { - let line = serde_json::from_str::(line).unwrap(); - let prompt = format!("{line}\n"); + let line = + serde_json::from_str::>(line).unwrap(); + let serde_json::Value::String(prompt) = &line["origin_prompt"] else { + unreachable!() + }; + let serde_json::Value::String(gold) = &line["gold"] else { + unreachable!() + }; + let prompt = format!("{prompt}\n"); // print_now!("{prompt}"); - let mut tokens = tokenizer.encode(&prompt); - let mut pos = 0; - for _ in 0..max_steps { - let next = lm(&tokens, pos); + let tokens = tokenizer.encode(&prompt); + let ans = tokenizer.decode(lm(&tokens, 0)); - pos += tokens.len(); - if next == eos { - break; - } - - let piece = tokenizer.decode(next); - print_now!("{piece}"); - tokens = vec![next]; + let result = gold.as_str() == ans; + if result { + correct += 1 } - println!() + total += 1; + + println!( + "({correct:>4}/{total:<4} {:6.2}%) {} {gold} {ans}", + 100. * (correct as f64 / total as f64), + if result { "✔" } else { "✘" }, + ) } return; }