diff --git a/forge/test/mlir/llama/test_llama_inference.py b/forge/test/mlir/llama/test_llama_inference.py index 38ef7806e..f70fc6bf3 100644 --- a/forge/test/mlir/llama/test_llama_inference.py +++ b/forge/test/mlir/llama/test_llama_inference.py @@ -6,6 +6,7 @@ import pytest import forge +from forge.verify.verify import verify from test.mlir.llama.utils.utils import load_model @@ -149,4 +150,5 @@ def test_llama_input_sequence_lengths(model_path, seq_len): # Compile the model and run fwd pass compiled_model = forge.compile(framework_model, input_ids) - logits = compiled_model(input_ids) + + verify([input_ids], framework_model, compiled_model)