-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Llama 3.2 1B Instruct on TPU v4, bumping transformers to 4.45.2 #109
base: main
Are you sure you want to change the base?
Conversation
Note that the two new test are just manual test, not pytests. The rope implementation is unvalidated - we just pray and are happy that it still generates tokens XD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will wait for the other contribution to be merged before merging this one, but thank you for contributing! Can you confirm the models you have tested with your changes?
tests/akg.py
Outdated
next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int() | ||
return next_token_id | ||
|
||
def _test_distributed_model_generation(model_id, max_new_tokens=20): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for tests, please create one test similar to tests/test_distributed_model.py
(or modify the existing one). To launch it, you can use pytest: python -m pytest -sv /path/to/test_mytest.py::test_my_test_function
.
Added llama3 rope_type implementation and changed default model to Llama 3.2 1B Instruct.
Create an adaptation of the HF transformer's llama3 rope_type implementation in modeling_llama.py.
Updated the dependency to the current transformer library version 4.45.2.
Added more logging to distributed_model.py as the TPU v4-8 vms love to hang at random places when running this code.
Fixes #80
Before submitting