Skip to content

Commit

Permalink
feat: 判题
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 17, 2025
1 parent 99d583d commit 971c1c1
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions test-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,27 +145,35 @@ pub fn test_infer(
let maybe_file = Path::new(prompt);
if maybe_file.is_file() {
let file = std::fs::read_to_string(maybe_file).unwrap();
let mut correct = 0;
let mut total = 0;
for line in file.lines() {
let line = serde_json::from_str::<String>(line).unwrap();
let prompt = format!("<s>{line}\n");
let line =
serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(line).unwrap();
let serde_json::Value::String(prompt) = &line["origin_prompt"] else {
unreachable!()
};
let serde_json::Value::String(gold) = &line["gold"] else {
unreachable!()
};
let prompt = format!("<s>{prompt}\n");

// print_now!("{prompt}");

let mut tokens = tokenizer.encode(&prompt);
let mut pos = 0;
for _ in 0..max_steps {
let next = lm(&tokens, pos);
let tokens = tokenizer.encode(&prompt);
let ans = tokenizer.decode(lm(&tokens, 0));

pos += tokens.len();
if next == eos {
break;
}

let piece = tokenizer.decode(next);
print_now!("{piece}");
tokens = vec![next];
let result = gold.as_str() == ans;
if result {
correct += 1
}
println!()
total += 1;

println!(
"({correct:>4}/{total:<4} {:6.2}%) {} {gold} {ans}",
100. * (correct as f64 / total as f64),
if result { "✔" } else { "✘" },
)
}
return;
}
Expand Down

0 comments on commit 971c1c1

Please sign in to comment.