[WIP] Compute bounds for Index scalars in lowered kernel #3850
Need to figure out where to put this. I think we should not concern ourselves with index type during lowering at all, and only do this afterward.
Review updated until commit b861294
This extends #3599 by also computing the minimal dtype required by the expressions in the lowered kernel. As in #3599, we cast from `nvfuser_index_t` to `int32_t` when passing coords to the TMA expression. However, unlike #3599, we actually verify that this is safe by checking the bounds of the inputs to those casts. This way we can safely use 64-bit indexing with TMA and know that we will not get silently incorrect results. We will also use 32-bit indexing more often, because TMA allows multi-dimensional indexing, so index variables rarely take extremely large values.

Fixes #3601
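To illustrate the kind of check described above, here is a minimal sketch of interval-based bounds propagation followed by a test of whether a value can be narrowed to `int32_t`. It is not the code in this PR: `Bounds`, `add`, `mul`, and `fitsInInt32` are hypothetical names, and the sketch assumes the intermediate bounds themselves fit in `int64_t`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical closed interval [lo, hi] bounding the value of an index scalar.
struct Bounds {
  int64_t lo;
  int64_t hi;
};

// Bounds of a + b. Assumes the sums do not overflow int64_t.
inline Bounds add(Bounds a, Bounds b) {
  return {a.lo + b.lo, a.hi + b.hi};
}

// Bounds of a * b. The extremes of a product over two intervals occur at the
// corner combinations. Assumes the products do not overflow int64_t.
inline Bounds mul(Bounds a, Bounds b) {
  const int64_t corners[4] = {
      a.lo * b.lo, a.lo * b.hi, a.hi * b.lo, a.hi * b.hi};
  return {*std::min_element(corners, corners + 4),
          *std::max_element(corners, corners + 4)};
}

// True if every value in the interval is representable as int32_t, i.e. a
// cast of the bounded scalar to int32_t cannot change its value.
inline bool fitsInInt32(Bounds b) {
  assert(b.lo <= b.hi);
  return b.lo >= std::numeric_limits<int32_t>::min() &&
         b.hi <= std::numeric_limits<int32_t>::max();
}
```

With these helpers, a per-dimension coordinate such as `mul({0, 1023}, {128, 128})` stays well inside the 32-bit range, while a linearized index such as `add(mul({0, 99999}, {100000, 100000}), {0, 99999})` does not. That is the intuition behind expecting 32-bit indexing to be sufficient more often with multi-dimensional TMA coordinates.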
TODO: add a few tests