From d37665af3725b6d231426f4da599fcc0c825ddf8 Mon Sep 17 00:00:00 2001 From: William Arnold Date: Tue, 17 Dec 2024 22:08:09 -0800 Subject: [PATCH] Fix CI --- src/levanter/main/viz_logprobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index e0aa1596d..682a29f6d 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -74,7 +74,7 @@ def main(config: VizGpt2Config): def compute_log_probs(model: LmHeadModel, example: LmExample): model = inference_mode(model, True) model = mp.cast_to_compute(model) - logprobs = compute_next_token_loss(model, example, reduction=None) + logprobs, where, _ = compute_next_token_loss(model, example) # roll forward to get the loss for each predicted token logprobs = hax.roll(logprobs, 1, Pos) return logprobs.rearrange((EvalBatch, Pos)).array