Test different input sequence lengths for Llama #1070

Open · wants to merge 4 commits into main
24 changes: 24 additions & 0 deletions forge/test/mlir/llama/test_llama_inference.py
@@ -6,6 +6,7 @@
import pytest

import forge
from forge.verify.verify import verify
from test.mlir.llama.utils.utils import load_model


@@ -124,3 +125,26 @@ def test_llama_inference_cache_cpu(model_path):
# Generated text
generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(generated_text)


@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
@pytest.mark.parametrize("seq_len", [2048, 512, 128])
def test_llama_input_sequence_lengths(model_path, seq_len):
if model_path == "openlm-research/open_llama_3b" and seq_len == 2048:
pytest.skip("ValueError: Data mismatch for openlm-research/open_llama_3b - sequence length of 2048")
# Load Model and Tokenizer
framework_model, tokenizer = load_model(model_path, num_hidden_layers=1)

# Adjust tokenizer for max sequence length padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.model_max_length = seq_len

prompt = "Q: What is the largest animal?\nA:"
input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
input_ids = input_ids.to(torch.int32)
Contributor
Why is this required? What is the default type for input IDs?

Do we expect that embedding input will always be int-based? If yes, maybe we should have a pass that will encompass this.

Contributor Author

The default type is int64, and we need to cast it due to the following issue: #952

Yep, embedding inputs are int-based (indices in the vocabulary), but I am not sure what you mean about another pass.
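For context, a minimal standalone sketch of the dtype observation above, using only transformers/torch calls outside of forge; the model path is just the one used in this test:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
ids = tokenizer("Q: What is the largest animal?\nA:", return_tensors="pt").input_ids
print(ids.dtype)  # torch.int64 -- HF tokenizers return int64 token indices by default
ids = ids.to(torch.int32)  # explicit downcast, working around issue #952
print(ids.dtype)  # torch.int32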


# Compile the model and run fwd pass
compiled_model = forge.compile(framework_model, input_ids)
Contributor

Do we want to test out bwd compile/run as well?

One general question: is there a clean way to test the backward part of a graph in isolation? For example, our compile should return a compiled context that contains information about each compiled component (e.g. fwd, bwd, loss, etc.).

Therefore, is there a clean way to call just the bwd part of the graph with random inputs, without needing to run the forward part or initialize the loss and optimizer parts of the training workflow?

Note: this is not a requirement for this PR, just a general question that could be useful here as well. I.e., can we have granular tests that target specific functionality rather than the whole workflow (e.g. only the bwd part of the model)? I see this as especially useful for the bwd generality push in the future. cc @vladimirjovanovicTT

Contributor

I think this is a must-have functionality as part of our training generality/BFS effort.
Let's discuss the implementation details offline.
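As a framework-agnostic illustration of the idea (forge's compiled-context API is not shown in this diff, so this is plain PyTorch, not a proposal for the actual interface): the backward path can be exercised with a random upstream gradient and no loss/optimizer in the loop. In eager PyTorch the forward still has to run once to build the graph, whereas a compiled bwd graph could in principle be fed random activations directly.

import torch

# Tiny stand-in module for a compiled fwd graph
model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8, requires_grad=True)

# Forward once to build the autograd graph...
y = model(x)

# ...then drive only the backward path with a random upstream gradient,
# without any loss function or optimizer.
grad_out = torch.randn_like(y)
grads = torch.autograd.grad(y, [x, *model.parameters()], grad_outputs=grad_out)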


verify([input_ids], framework_model, compiled_model)
1 change: 1 addition & 0 deletions forge/test/mlir/llama/utils/utils.py
@@ -16,6 +16,7 @@ def load_model(model_path="openlm-research/open_llama_3b", **kwargs):
config.use_cache = kwargs.get("use_cache", False)
config.output_attentions = kwargs.get("output_attentions", False)
config.output_hidden_states = kwargs.get("output_hidden_states", False)
config.num_hidden_layers = kwargs.get("num_hidden_layers", 26)
Contributor

Was this intentional?

Any specific reason for updating the original number of hidden layers?

Contributor Author

Yep, that's per our discussion in the last sync. Running Llama with all layers takes quite some time, and since this is not an e2e/demo test, I thought it made sense to speed it up by using a single layer.
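For reference, a sketch of the same trick in plain transformers (mirroring what load_model in utils.py does): shrinking num_hidden_layers in the config before from_pretrained loads only one decoder layer, so compile and run times drop sharply. HF will warn about unused checkpoint weights for the dropped layers, which is expected here.

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("openlm-research/open_llama_3b")
config.num_hidden_layers = 1  # keep a single decoder layer so the test runs quickly
model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b", device_map="auto", config=config
)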


# Load the model
framework_model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", config=config)