Add mx_fp8_bf16 kernel #1637

Open. drisspg wants to merge 1 commit into base: drisspg/stack/30

@drisspg drisspg commented Jan 29, 2025

Stacked PRs:


Add mx_fp8_bf16 kernel

I will flesh this out more, but this PR moves over the kernel from here: https://github.com/drisspg/driss_torch/blob/2813322f0b0f9a0f0fc8d382090ad0aaecf3468a/src/mx_fp8_bf16.cu#L162

This does fp8 x fp8 with E8M0 scales and group_size hard-coded to 32. The format for the scales is the same as the one cuBLASLt expects. I have written a PyTorch function that converts the [n_rows, n_cols//32] scales into the expected format:
https://github.com/drisspg/transformer_nuggets/blob/382cb0f19a5f615827174289b8ef552419d51fea/transformer_nuggets/mx/to_blocked.py#L11
This was surprisingly hard-fought and would not have been possible without @albanD 😊
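
To make the scale handling above concrete, here is a minimal pure-Python sketch. It is not the kernel or the linked `to_blocked` function. It assumes that E8M0 scales decode as powers of two with bias 127, as in the OCP MX format, and that the blocked layout is built from 128x4 tiles of scales with the in-tile offset `(row % 32) * 16 + (row // 32) * 4 + col`, which is how I understand the cuBLAS 1D block-scaling layout. All function names are hypothetical, and padding with 0 is an assumption.

```python
GROUP_SIZE = 32  # group size that the PR description says is hard-coded


def ceil_div(a, b):
    """Integer division rounding up."""
    return (a + b - 1) // b


def scale_shape(n_rows, n_cols):
    """Shape of the scale matrix for an [n_rows, n_cols] fp8 operand."""
    assert n_cols % GROUP_SIZE == 0, "n_cols must be a multiple of 32"
    return n_rows, n_cols // GROUP_SIZE


def e8m0_to_float(biased_exp):
    """Decode one E8M0 byte: exponent only, bias 127, 0xFF reserved for NaN."""
    if biased_exp == 0xFF:
        return float("nan")
    return 2.0 ** (biased_exp - 127)


def blocked_offset(row, col, n_scale_cols):
    """Flat offset of scale (row, col) in the assumed blocked layout.

    Scales are grouped into 128x4 tiles (512 entries each), tiles are
    stored row-major, and rows inside a tile are interleaved in steps of 32.
    """
    n_col_tiles = ceil_div(n_scale_cols, 4)
    tile = (row // 128) * n_col_tiles + (col // 4)
    r, c = row % 128, col % 4
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c


def to_blocked_list(scales, pad=0):
    """Rearrange a row-major list of lists of scale bytes into a flat list.

    Rows are padded to a multiple of 128 and columns to a multiple of 4;
    the pad value is an assumption of this sketch.
    """
    rows, cols = len(scales), len(scales[0])
    out = [pad] * (ceil_div(rows, 128) * 128 * ceil_div(cols, 4) * 4)
    for r in range(rows):
        for c in range(cols):
            out[blocked_offset(r, c, cols)] = scales[r][c]
    return out
```

For example, `to_blocked_list([[127] * 4 for _ in range(128)])` fills exactly one 512-entry tile, and every entry decodes to a scale of 1.0 under `e8m0_to_float`.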

This lets #1625 avoid any dependency on PyTorch core updates while we add the required dtypes and cuBLAS bindings: pytorch/pytorch#145562

Follow up

Config needs more tuning

pytorch-bot bot commented Jan 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1637

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 7473aca with merge base b2fb664:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Jan 29, 2025
stack-info: PR: #1637, branch: drisspg/stack/31
@facebook-github-bot added the CLA Signed label Jan 29, 2025
@drisspg added the topic: new feature label Jan 29, 2025

@vkuzo vkuzo left a comment

nice! if CI is green - looks good! I think this should have at least one numerical test though. Can be a follow-up PR if needed.

@albanD albanD left a comment

Very cool!

setup.py: review thread resolved (outdated)
torchao/ops.py: review thread resolved
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 3, 2025 21:47
drisspg added a commit that referenced this pull request Feb 3, 2025
stack-info: PR: #1637, branch: drisspg/stack/31
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 3, 2025 21:47
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 3, 2025 23:24
drisspg added a commit that referenced this pull request Feb 3, 2025
stack-info: PR: #1637, branch: drisspg/stack/31
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 3, 2025 23:25
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 3, 2025 23:30
drisspg added a commit that referenced this pull request Feb 3, 2025
stack-info: PR: #1637, branch: drisspg/stack/31
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 3, 2025 23:30
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 4, 2025 19:55
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 4, 2025 19:55
@drisspg drisspg mentioned this pull request Feb 4, 2025
stack-info: PR: #1637, branch: drisspg/stack/31
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 4, 2025 19:55
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 4, 2025 19:57
@drisspg drisspg changed the base branch from drisspg/stack/30 to main February 4, 2025 20:00
@drisspg drisspg changed the base branch from main to drisspg/stack/30 February 4, 2025 20:00
@drisspg drisspg added the mx label Feb 4, 2025
Labels: CLA Signed, mx, topic: new feature
4 participants