
[QST] Gather/Scatter in cute/cutlass 3 #1330

Closed

akamiru opened this issue Feb 3, 2024 · 23 comments

Labels: question

Comments

@akamiru commented Feb 3, 2024

Hi everybody,

I'm currently trying to write a trainer for a very small and oddly shaped network which requires a lot of gather/scatter. E.g., one layer looks like this:
C = A0 * B0 + A1 * B1 ...

where the A matrices are 32x8 and the B matrices are 8xN. There are 36 possible A matrices, so I could load them all to shared memory once. The B matrices need to be gathered along the columns (six in total). So I probably need to gather them all using cp.async and then do a single 32x48 * 48xN GEMM (F32F16F16F32), I guess? What's the best way to approach this in cute? Could this be done in cutlass?
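
Spelling out why a single GEMM should work: writing s_j for the selected A matrices and g_j for the gathered B blocks,

$$C = \sum_{j=0}^{5} A_{s_j} B_{g_j} = \begin{bmatrix} A_{s_0} & \cdots & A_{s_5} \end{bmatrix} \begin{bmatrix} B_{g_0} \\ \vdots \\ B_{g_5} \end{bmatrix},$$

i.e. six 32x8 blocks side by side (32x48) times six 8xN blocks stacked (48xN).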

Help would really be appreciated.

@thakkarV (Collaborator) commented Feb 3, 2024

Sounds like you want a grouped GEMM that supports gather/scatter? Have you taken a look at example 52 for inspiration? Happy to help with the design; this is fairly natural to express in CuTe with the extensions used in the Hopper gather/scatter GEMM example.

@akamiru (Author) commented Feb 4, 2024

Sadly I only have access to SM89, so I haven't looked into the Hopper examples too much. To be honest, I'm still rather new to GPU computing and I'm having a hard time understanding how cute/cutlass works. In particular, the tiny sizes of my matrices don't seem to match the supported layouts.

Thinking about it, a grouped GEMM with gather/scatter which runs on SM89 would pretty much allow me to do everything my trainer needs, except for the first layer, which is sparse. Would you be willing to discuss my network and help me make a plan for how to tackle it via Discord?

@thakkarV (Collaborator) commented Feb 4, 2024

I'm on vacation for a couple of weeks but can help asynchronously on this thread. The concepts presented in example 52 apply pretty much 1:1 here. I'm specifically referring to how the gather tensor layouts are transcribed: you can use the same custom indexed-gather strides from that example for your use case. I recommend starting with a single-file CuTe kernel like the ones in the CuTe tutorial.
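
To make "custom indexed-gather strides" concrete, here is a simplified sketch of the idea from example 52's gather_tensor.hpp (names follow that header loosely; this is a sketch, and the real CustomStride carries a few more operator overloads):

#include <cute/tensor.hpp>

// Maps a logical index to a physical index through an index buffer.
template <class Index>
struct IndexedGather {
  Index const* indices;

  template <class I>
  CUTE_HOST_DEVICE constexpr Index operator()(I i) const { return indices[i]; }
};

// Stride that applies the gather functor before scaling, so coordinate m
// contributes func(m) * stride to the linear offset instead of m * stride.
template <class Func, class Stride>
struct CustomStride {
  Func   func;
  Stride stride;

  template <class I>
  CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) {
    return s.func(i) * s.stride;
  }
};

// A gathered view of a row-major M x N matrix: view(m, n) reads
// ptr[idx[m] * ld + n], i.e. row m of the view is physical row idx[m].
template <class T, class Index>
CUTE_HOST_DEVICE auto
make_gathered_matrix(T* ptr, int M, int N, int ld, Index const* idx) {
  using namespace cute;
  auto stride = make_stride(
      CustomStride<IndexedGather<Index>, int>{IndexedGather<Index>{idx}, ld},
      Int<1>{});
  return make_tensor(make_gmem_ptr(ptr), make_layout(make_shape(M, N), stride));
}

Because the indirection lives entirely in the stride, the gathered view can then be tiled with local_tile and copied like any other CuTe tensor.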

@hwu36 (Collaborator) commented Feb 5, 2024

Does https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu meet your need to run multiple GEMMs in parallel?

Does https://github.com/NVIDIA/cutlass/blob/main/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu meet your need to fuse gather/scatter into a GEMM?

If yes to both, you could just merge ex36 into ex24, which is fairly easy. Gather/scatter adds three template parameters to the highest-level templates and three index-pointer arguments to the top-level Arguments. So what you need to do is add these to the grouped GEMM interface.

Grouped GEMM and gather/scatter use the same underlying GEMM implementation, so you just need to do some plumbing at the top (device and kernel) levels.
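
Roughly, the knobs look like this (types and parameter order here are illustrative, from memory; check ex36 for the authoritative instantiation):

#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"

using ElementInputA      = cutlass::half_t;
using ElementInputB      = cutlass::half_t;
using ElementOutput      = float;
using ElementAccumulator = float;
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator>;

using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // instruction shape
    EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                        // stages
    8, 8,                                     // alignment A / B
    cutlass::arch::OpMultiplyAdd,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone,
    true,   // GatherA  <- the three extra template parameters
    false,  // GatherB
    true>;  // ScatterD

// The matching index lists are the three extra trailing members of
// Gemm::Arguments: gather-A indices, gather-B indices (nullptr if unused),
// and scatter-D indices. For grouped GEMM you need one such list per problem.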

@shangz-ai

@jeromeku (Contributor) commented Feb 15, 2024

@hwu36

I'm also interested in combining grouped gemms with gather / scatter fusion for Ampere (non-Hopper) architectures.

I see that both examples use the same underlying kernel GemmUniversal.

Other than adapting the top-level arguments and templates per your previous answer, what other changes do I need to make to the underlying code in GemmGrouped so that it properly constructs, initializes, and passes the args for the individual gather / scatter kernels?

I.e., other than instantiating the underlying gemm kernel for the grouped gemm per above and adding the tensor indices to the grouped gemm arguments struct, what else needs to be done to pass these indices to the gather / scatter gemms? What I'm not clear on is how the top-level GemmGrouped arguments interact with the underlying kernels (GemmUniversal).

@hwu36 (Collaborator) commented Feb 15, 2024

@jackkosaian ^ ^

@jackkosaian (Contributor)

@jeromeku,

It sounds like you've already figured out where new Arguments should be placed: here.

You'll also need to add them to the kernel's Params struct here, similar to how they are added for GemmUniversal::Params here (but note that for grouped GEMM you'll have arrays of gather/scatter index pointers, since you want one list of indices per problem in the group).

Then, you need to augment the kernel::GemmGrouped::operator() to (1) determine which gather/scatter pointer to use for a given tile being processed by the thread block, and (2) pass these pointers down to lower levels of the kernel.

(1) can be achieved by indexing into the gather/scatter pointers in Params using problem_idx similar to what grouped GEMM does for A/B pointers here.

(2) can be done by following the pattern in which GemmUniversal::operator() uses the gather/scatter indices (e.g., here).
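
In code, (1) and (2) together amount to something like this inside kernel::GemmGrouped::operator() (the member names are hypothetical additions; they mirror how ptr_A / ptr_B are stored):

// Hypothetical Params extension for grouped gather/scatter GEMM: arrays of
// index pointers, one list per problem, just like ptr_A / ptr_B.
struct GroupedGatherScatterParams {
  int const* const* ptr_gather_A_indices;   // [problem_count]
  int const* const* ptr_scatter_D_indices;  // [problem_count]
};

// (1) select this problem's index lists by problem_idx, exactly the way
//     grouped GEMM selects its A/B pointers;
// (2) then pass them down to the A-operand and D-output tile iterators at
//     construction time, following GemmUniversal::operator().
__device__ void select_indices(GroupedGatherScatterParams const& params,
                               int problem_idx,
                               int const*& gather_A_indices,
                               int const*& scatter_D_indices) {
  gather_A_indices  = params.ptr_gather_A_indices[problem_idx];
  scatter_D_indices = params.ptr_scatter_D_indices[problem_idx];
}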

I hope this helps!

@jeromeku (Contributor)

@jackkosaian

Thanks for the clear explanation!

For the gather / scatter kernels, if I'm gathering rows of A (and scattering into rows of C / D), are there additional checks that need to be implemented to ensure that each gemm size is valid?

For example, if I'm using TensorOps, the minimum M is 16 as predetermined by the tensor core instruction and thus the warp and tile shapes need to be multiples of 16.

If the gathered M doesn't meet this requirement, would can_implement need to be changed to enforce this condition by, e.g., padding to a minimum block size? Is this implemented in the codebase, or are there related examples?

@jackkosaian (Contributor)

Any sort of padding would need to be handled externally to can_implement. You would need to pad your tensors, problem shapes, etc. before setting them in the Arguments struct.
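
For example (a trivial sketch; tile_m is whatever threadblock tile M your kernel was instantiated with):

// Round the gathered row count up to the tile granularity before building
// the Arguments struct; the tensors and index buffers must be padded to
// the same extent.
int round_up_to_tile(int m, int tile_m) {
  return (m + tile_m - 1) / tile_m * tile_m;
}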

@jeromeku (Contributor)

Quick question:

When filling a tensor with BlockFillSequential like this:

cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
cutlass::reference::host::BlockFillSequential(
    tensor_a.host_data(), problem_size.m() * problem_size.k());

for ElementInputA = cutlass::half_t and m = k = 128, this gives a sequence only up to 2048, after which it repeats 2048 for the remaining values.

Guessing that I'm doing something wrong here...

@hwu36 (Collaborator) commented Feb 17, 2024

There is no 2049 in fp16, so you are doing the correct thing.

@jeromeku (Contributor)

Thanks @hwu36

When I repeat the above using host::TensorFillSequential, I get an increasing sequence up to 16384 for m = k = 128 as expected, whereas for BlockFillSequential, the sequence stops at 2048 (and repeats 2048 thereafter).

@hwu36 (Collaborator) commented Feb 18, 2024

If you check the result returned by host::TensorFillSequential, you won't see 2049.

@hwu36 (Collaborator) commented Feb 18, 2024

host::TensorFillSequential uses a slightly different way to calculate the values, but past 2048 both are limited by fp16 precision: fp16 has a 10-bit mantissa, so in [2048, 4096) only even integers are representable.

In fp16:

2048 + 1 = 2048 <- the BlockFillSequential way (accumulate in fp16)

fp16(2049) = 2048 <- the TensorFillSequential way (convert each index)
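
A quick host-side check of both paths (sketch, assuming cutlass/half.h):

#include <iostream>
#include "cutlass/half.h"

int main() {
  // Both fill styles hit the same wall at 2048, just via different routes.
  cutlass::half_t x(2048.0f);
  std::cout << float(x + cutlass::half_t(1.0f)) << "\n";  // BlockFillSequential way: prints 2048
  std::cout << float(cutlass::half_t(2049.0f)) << "\n";   // TensorFillSequential way: prints 2048
  return 0;
}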

@jeromeku (Contributor)

@jackkosaian

Are there any examples of gather / scatter fusion and grouped_gemm specifically for Ampere architectures using Cutlass 3.0+ and CuTe?

How would one implement the above (combining gather / scatter and grouped_gemm) using the 3.0 API as opposed to the legacy 2.0 interface?

@jackkosaian (Contributor) commented Feb 22, 2024

> Are there any examples of gather / scatter fusion and grouped_gemm specifically for Ampere architectures using Cutlass 3.0+ and CuTe?

We do not have examples of this.

> How would one implement the above (combining gather / scatter and grouped_gemm) using the 3.0 API as opposed to the legacy 2.0 interface?

My suggestion would be to take a look at the CUTLASS 3 examples for gather/scatter and grouped GEMM (each of which currently targets Hopper). You could consider adapting these to use SM80 CUTLASS 3 mainloops (similar to unit tests like this one). Note, however, that GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API.

@thakkarV (Collaborator)

> GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API

That said, do not be discouraged. Carefully crafted data and thread layouts can hit 95% of peak perf with the 3.x mainloop as well :)

@mnicely (Collaborator) commented Feb 22, 2024

@jeromeku, has your question been answered?

@jeromeku (Contributor)

@jackkosaian

How would I implement a gather_scatter with broadcasted scale factor in the epilogue?

More specifically:

  • gather rows from A, scatter to D
  • each scattered D row should be multiplied by its own scale factor, gathered from a vector of such factors

Seems like I'd have to customize default_epilogue_with_broadcast / epilogue_with_broadcast, default_gemm_with_broadcast, gemm_universal_with_broadcast, and gemm_with_fused_epilogue to accept gather / scatter indices?

@hwu36 (Collaborator) commented Mar 20, 2024

#1398 (comment) from @apuaaChen can help you with the vectorized scaling part. You still need to reuse the gather/scatter logic in place.

@thakkarV (Collaborator)

Hi! We just released CUTLASS 3.5, and it contains an example of a CUTLASS 3.x based gather/scatter convolution kernel.

@mnicely (Collaborator) commented Apr 5, 2024

@jeromeku, with the release of v3.5, can we close this issue?

@jeromeku (Contributor) commented Apr 5, 2024

@mnicely yes, thanks.

@mnicely closed this as completed Apr 5, 2024