[QST] Gather/Scatter in cute/cutlass 3 #1330
Sounds like you want a grouped GEMM that supports gather/scatter? Have you taken a look at example 52 for inspiration? Happy to help with the design, but using CuTe is fairly natural with the extensions in the Hopper gather/scatter GEMM example.
Sadly I only have access to SM89, so I haven't looked into the Hopper examples too much. To be honest, I'm still rather new to GPU computing and I'm having a hard time understanding how CuTe/CUTLASS works. Especially the tiny sizes of my matrices don't seem to match the supported layouts. Thinking about it, a grouped GEMM with scatter/gather which runs on SM89 would pretty much allow me to do everything my trainer needs, except for the first layer, which is sparse. Would you be willing to discuss my network and help me make a plan for how to tackle it via Discord?
I'm on vacation for a couple weeks but can help asynchronously on this thread. The concepts presented in example 52 apply pretty much 1:1 over here; I'm specifically referring to how the gather tensor layouts are transcribed. You can use the same custom indexed gather strides used in that example for your use case. I recommend starting with a single-file CuTe kernel like the CuTe tutorial.
Does https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu meet your need to run multiple GEMMs in parallel? Does https://github.com/NVIDIA/cutlass/blob/main/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu meet your need to fuse gather/scatter into a GEMM? If yes to both, you could just merge ex36 into ex24, which is fairly easy. Gather/scatter adds three template parameters (for gathering A, gathering B, and scattering D) to the highest-level GEMM template, and three corresponding index-pointer arguments to the top-level arguments. So what you need to do is add these to the grouped GEMM interface. Grouped GEMM and gather/scatter use the same underlying GEMM implementation, so you just need to do some plumbing at the top-level device and kernel layers.
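For reference, the fused kernel computes the same thing as this plain host-side sketch (names here are illustrative, not the CUTLASS API, and exactly which axes are gathered/scattered is configurable):

```cpp
#include <vector>

// Host-side reference of what a gather/scatter GEMM (ex36-style) computes in
// one possible configuration: rows of A are gathered, columns of B are
// gathered, and rows of D are scattered through index arrays.
void gather_scatter_gemm_ref(
    int M, int N, int K,
    std::vector<float> const& A,        // (rows of A) x K, row-major, row stride lda
    std::vector<float> const& B,        // K x (cols of B), row-major, row stride ldb
    std::vector<float>&       D,        // output, row-major, row stride ldd
    std::vector<int> const& gather_A,   // M row indices into A
    std::vector<int> const& gather_B,   // N column indices into B
    std::vector<int> const& scatter_D,  // M row indices into D
    int lda, int ldb, int ldd) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[gather_A[i] * lda + k] * B[k * ldb + gather_B[j]];
      }
      D[scatter_D[i] * ldd + j] = acc;
    }
  }
}
```

Merging this into grouped GEMM then just means carrying one such triple of index pointers per problem in the group.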
I'm also interested in combining grouped GEMMs with gather/scatter fusion for Ampere (non-Hopper) architectures. I see that both examples use the same underlying kernel, GemmUniversal. Other than adapting the top-level arguments and templates per your previous answer, what other changes do I need to make to the underlying code in GemmGrouped so that it properly constructs, initializes, and passes the args for the individual gather/scatter kernels? I.e., other than instantiating the underlying kernel with the gather/scatter template parameters?
@jackkosaian ^ ^
It sounds like you've already figured out where the new arguments would need to be added at the device level. You'll also need to add them to the kernel's arguments and params structs. Then, you need to augment the kernel so that it (1) picks out the per-problem gather/scatter index pointers and (2) applies them when accessing the operands. (1) can be achieved by indexing into the gather/scatter pointer arrays in the params with the current problem index. (2) can be done by following the pattern of use of the gather/scatter indices that is used by the non-grouped gather/scatter GEMM. I hope this helps!
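To make that plumbing concrete, here is a minimal sketch with hypothetical names (this is not the actual CUTLASS struct layout) of how per-problem gather/scatter pointers could be carried through arguments/params and selected inside the grouped kernel:

```cpp
// Hypothetical additions to a grouped-GEMM Arguments/Params pair. Each
// problem in the group gets its own gather/scatter index arrays, so the
// host passes arrays of device pointers sized by problem_count.
struct GroupedGatherScatterArguments {
  int problem_count = 0;
  int const* const* gather_A_indices = nullptr;  // [problem_count] device pointers
  int const* const* gather_B_indices = nullptr;
  int const* const* scatter_D_indices = nullptr;
};

struct GroupedGatherScatterParams {
  int problem_count;
  int const* const* gather_A_indices;
  int const* const* gather_B_indices;
  int const* const* scatter_D_indices;

  explicit GroupedGatherScatterParams(GroupedGatherScatterArguments const& args)
      : problem_count(args.problem_count),
        gather_A_indices(args.gather_A_indices),
        gather_B_indices(args.gather_B_indices),
        scatter_D_indices(args.scatter_D_indices) {}
};

// Once the grouped kernel's scheduler has picked problem_idx, the
// per-problem pointers are selected like this and then used exactly the
// way the non-grouped gather/scatter GEMM uses its index pointers.
inline void select_problem(GroupedGatherScatterParams const& params, int problem_idx,
                           int const*& gather_A, int const*& gather_B, int const*& scatter_D) {
  gather_A  = params.gather_A_indices  ? params.gather_A_indices[problem_idx]  : nullptr;
  gather_B  = params.gather_B_indices  ? params.gather_B_indices[problem_idx]  : nullptr;
  scatter_D = params.scatter_D_indices ? params.scatter_D_indices[problem_idx] : nullptr;
}
```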
Thanks for the clear explanation! One follow-up: in the case that the gathered dimension doesn't line up with the tile/alignment requirements, does the kernel handle that, or does the input need to be padded?
Any sort of padding would need to be handled externally to CUTLASS.
Quick question: when filling a tensor with

```cpp
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
cutlass::reference::host::BlockFillSequential(
    tensor_a.host_data(), problem_size.m() * problem_size.k());
```

the values stop increasing once they reach 2048 when ElementInputA is a half-precision type. Guessing that I'm doing something wrong here...
There is no 2049 in fp16. You are doing the correct thing.
Thanks @hwu36. When I repeat the above using...
If you check the result returned by...
In fp16, 2048 + 1 = 2048: the nearest representable value to 2049 is 2048, so the sum rounds back down.
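A quick way to see this with CUTLASS's own half type (just an illustration of the rounding, not part of the example):

```cpp
#include <iostream>
#include "cutlass/half.h"

int main() {
  // fp16 has a 10-bit mantissa, so above 2048 only every other integer is representable.
  cutlass::half_t x(2048.0f);
  cutlass::half_t one(1.0f);
  std::cout << float(x + one) << "\n";                   // 2048: the exact sum 2049 rounds back down
  std::cout << float(cutlass::half_t(2049.0f)) << "\n";  // also 2048
  std::cout << float(cutlass::half_t(2050.0f)) << "\n";  // 2050 is exactly representable
  return 0;
}
```

This is why a sequential fill of an fp16 tensor plateaus at 2048.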
Are there any examples of this? How would one implement the above (combining grouped GEMM with gather/scatter) using the CUTLASS 3.x API?
We do not have examples of this.
My suggestion would be to take a look at the CUTLASS 3 examples for gather/scatter and grouped GEMM (each of which currently targets Hopper). You could consider adapting these to use SM80 CUTLASS 3 mainloops (similar to unit tests like this one). Note, however, that GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API.
That said, do not be discouraged: carefully crafted data and thread layouts can hit 95% of peak perf on the 3.x mainloops as well :)
@jeromeku has your question been answered?
How would I implement a vectorized scaling on top of this? More specifically: it seems like I'd have to customize the epilogue while keeping the gather/scatter logic in place.
#1398 (comment) from @apuaaChen can help you with the vectorized scaling part. You still need to reuse the gather/scatter logic in place.
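In case it helps pin down what the vectorized scaling computes, and assuming a per-row scale vector is what's being fused (adapt for per-column scaling if needed), the epilogue would be equivalent to this reference:

```cpp
#include <vector>

// Reference for a per-row scaling epilogue: D = diag(scale) * (A * B) + beta * C,
// where acc holds the GEMM accumulator for one problem (M x N, row-major).
void row_scale_epilogue_ref(int M, int N,
                            std::vector<float> const& acc,
                            std::vector<float> const& scale,  // length M
                            std::vector<float> const& C,
                            float beta,
                            std::vector<float>& D) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      D[i * N + j] = scale[i] * acc[i * N + j] + beta * C[i * N + j];
    }
  }
}
```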
Hi! We just released CUTLASS 3.5, and it contains an example of a CUTLASS 3.x based gather/scatter convolution kernel.
@jeromeku with the release of v3.5, can we close this issue?
@mnicely yes, thanks.
Hi everybody,
I'm currently trying to write a trainer for a very small and oddly shaped network which requires a lot of gather/scatter. E.g., one layer looks like this:
C = A0 * B0 + A1 * B1 ...
where the A matrices are 32x8 and the B matrices are 8xN. There are 36 possible A matrices, so I could load them all into shared memory once. The B matrices need to be gathered along the columns (six in total). So I probably need to gather them all using cp.async and then do a single 32x48 * 48xN GEMM (F32F16F16F32), I guess? What's the best way to approach this in CuTe? Could this be done in CUTLASS?
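To make the shapes concrete, here is a plain reference of what one such layer computes (shapes and index names are just my example):

```cpp
#include <vector>

// Reference for one layer: C (32 x N) = sum over 6 selected pairs of
// A_i (32 x 8) * B_i (8 x N). a_select picks which of the 36 A blocks is
// used for each pair; B holds the six gathered 8 x N blocks.
void layer_ref(int N,
               std::vector<std::vector<float>> const& A,  // 36 blocks, each 32*8, row-major
               std::vector<std::vector<float>> const& B,  // 6 gathered blocks, each 8*N, row-major
               std::vector<int> const& a_select,          // 6 indices into A
               std::vector<float>& C) {                   // 32*N, row-major, zero-initialized
  for (int s = 0; s < 6; ++s) {
    std::vector<float> const& As = A[a_select[s]];
    std::vector<float> const& Bs = B[s];
    for (int m = 0; m < 32; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < 8; ++k)
          C[m * N + n] += As[m * 8 + k] * Bs[k * N + n];
  }
}
```

Stacking the six selected A blocks side by side (32x48) and the six gathered B blocks on top of each other (48xN) gives the same C, which is why I'm hoping a single fused GEMM would work.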
Help would really be appreciated.