Add mx_fp8_bf16 kernel #1637
base: drisspg/stack/30
Conversation
nice! if CI is green - looks good! I think this should have at least one numerical test though. Can be a follow-up PR if needed.
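A numerical test along those lines could look roughly like the sketch below. The op name torch.ops.torchao.mx_fp8_bf16, the uint8 E8M0 scale encoding, the argument order, and the import path for the to_blocked helper are all assumptions for illustration, not something this PR confirms.

```python
import torch

# Helper linked in the PR description; module path is an assumption.
from transformer_nuggets.mx.to_blocked import to_blocked

GROUP_SIZE = 32


def quantize_mx_fp8(x: torch.Tensor):
    """Toy MX quantization: per-32-element power-of-two scales + e4m3 data."""
    rows, cols = x.shape
    xg = x.float().reshape(rows, cols // GROUP_SIZE, GROUP_SIZE)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=2.0 ** -127)
    # Power-of-two scale so each group fits within e4m3's max normal (448).
    exp = torch.log2(amax / 448.0).ceil().clamp(-127, 127)
    scale = torch.exp2(exp)
    data = (xg / scale).reshape(rows, cols).to(torch.float8_e4m3fn)
    # E8M0-style biased-exponent encoding packed into uint8 (assumed input format).
    e8m0 = (exp.squeeze(-1) + 127).to(torch.uint8)
    return data, scale.squeeze(-1), e8m0  # scales: [rows, cols // 32]


def test_mx_fp8_bf16_vs_reference():
    M, K, N = 128, 256, 128
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

    a_fp8, a_scale, a_e8m0 = quantize_mx_fp8(a)
    b_fp8, b_scale, b_e8m0 = quantize_mx_fp8(b)

    # Reference: dequantize the same fp8 data and matmul in fp32,
    # so only accumulation/output precision differs from the kernel.
    a_ref = (a_fp8.float().reshape(M, -1, GROUP_SIZE) * a_scale.unsqueeze(-1)).reshape(M, K)
    b_ref = (b_fp8.float().reshape(N, -1, GROUP_SIZE) * b_scale.unsqueeze(-1)).reshape(N, K)
    ref = a_ref @ b_ref.t()

    # Kernel under test, with scales in the blocked cublasLt layout (op name illustrative).
    out = torch.ops.torchao.mx_fp8_bf16(a_fp8, b_fp8, to_blocked(a_e8m0), to_blocked(b_e8m0))

    torch.testing.assert_close(out.float(), ref, atol=1e-1, rtol=1e-1)
```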
Very cool!
Stacked PRs:
Add mx_fp8_bf16 kernel
Will flesh this out more, but this moves over the kernel from here: https://github.com/drisspg/driss_torch/blob/2813322f0b0f9a0f0fc8d382090ad0aaecf3468a/src/mx_fp8_bf16.cu#L162
This does fp8 x fp8 with E8M0 scales and group_size hard-coded to 32. The scale format is the same one cublasLt expects. I have created a PyTorch function that converts the [n_rows, n_cols//32] scales into the expected format:
https://github.com/drisspg/transformer_nuggets/blob/382cb0f19a5f615827174289b8ef552419d51fea/transformer_nuggets/mx/to_blocked.py#L11
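For context, the conversion pads the [n_rows, n_cols//32] scale tensor up to whole (128, 4) tiles and swizzles each tile into the blocked layout cublasLt consumes. Below is a rough sketch of that helper; the linked file is authoritative, and the exact swizzle details here may differ slightly.

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def to_blocked(scales: torch.Tensor) -> torch.Tensor:
    """Rearrange a [rows, cols] scale matrix into the blocked layout used by
    cublasLt block-scaled matmuls (sketch of the linked helper)."""
    rows, cols = scales.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Pad up to whole (128, 4) tiles.
    padded = scales
    if (rows, cols) != (n_row_blocks * 128, n_col_blocks * 4):
        padded = torch.zeros(
            (n_row_blocks * 128, n_col_blocks * 4),
            device=scales.device,
            dtype=scales.dtype,
        )
        padded[:rows, :cols] = scales

    # View as (128, 4) tiles, then swizzle each tile into a 32x16 pattern
    # (four 32-row groups interleaved) and lay the tiles out contiguously.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return rearranged.flatten()
```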
This was surprisingly hard-fought and would not have been possible without @albanD 😊
This allows PR #1625 to avoid any dependency on PyTorch core updates while we add the required dtypes and cuBLAS bindings: pytorch/pytorch#145562
Follow-up
Config needs more tuning