diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index a698be1..e459c2c 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -13,3 +13,4 @@ gguf.workspace = true tensor.workspace = true env_logger.workspace = true cli-table = "0.4.9" +serde_json = "1.0" diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index b52d784..07cee67 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -142,6 +142,34 @@ pub fn test_infer( }}; } + let maybe_file = Path::new(prompt); + if maybe_file.is_file() { + let file = std::fs::read_to_string(maybe_file).unwrap(); + for line in file.lines() { + let line = serde_json::from_str::(line).unwrap(); + let prompt = format!("{line}\n"); + + // print_now!("{prompt}"); + + let mut tokens = tokenizer.encode(&prompt); + let mut pos = 0; + for _ in 0..max_steps { + let next = lm(&tokens, pos); + + pos += tokens.len(); + if next == eos { + break; + } + + let piece = tokenizer.decode(next); + print_now!("{piece}"); + tokens = vec![next]; + } + println!() + } + return; + } + print_now!("{prompt}"); let mut tokens = tokenizer.encode(prompt);