
[FEA]question about Cute supports uint1b_t GeMM #1863

Closed
CalebDu opened this issue Oct 11, 2024 · 4 comments

@CalebDu
Contributor

CalebDu commented Oct 11, 2024

I want to develop a customized fused GEMM kernel for the uint1b_t dtype in CuTe. I note that CuTe has no official example or test for uint1b_t GEMM, and MMA_Traits only instantiates struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>, with a wrong ABLayout. So I created PR #1856 to fix the wrong ABLayout and add more MMA_Traits and MMA_op support for uint1b_t.
After that, I started to implement a simple GEMM for uint1b_t, but I find that CuTe doesn't seem to support subbyte pack/unpack like SubbyteReference in CUTLASS. It leads to my customized GEMM getting wrong results.
[Image: test code]

To prove it, I ran the simple test shown above: I create a CUTLASS tensor A with shape (1024, 128) in uint1b_t dtype and initialize every element with 0x1. CUTLASS packs 8 uint1b_t into 1 uint8_t, so A is allocated 1024*128/8 = 16 kB of memory. Then I use the A.device_data() pointer to build a CuTe tensor and print both tensors, as shown below. The CUTLASS tensor A prints the correct result, but the CuTe tensor A prints 255 (0b11111111), because the uint1b_t values are stored in uint8_t without bit pack/unpack.
[Images: CUTLASS and CuTe tensor printouts]

I want to know how CuTe can support uint1b_t GEMM with pack/unpack to get correct results. I can help by submitting a PR.

@CalebDu CalebDu added ? - Needs Triage feature request New feature or request labels Oct 11, 2024
@thakkarV
Collaborator

thakkarV commented Oct 11, 2024

> It leads to my customized GEMM getting wrong results.

Not having sub-byte pack/unpack should not lead to the GEMM producing incorrect results if you are using SM80_16x8x256_S32U1U1S32_TN_XORPOPC, because the MMA itself does not require any unpacking of the bits unless your GMEM inputs are not K-major.

So are you trying to accept MN-major input tensors, or is it only the printing that is broken?

@CalebDu
Contributor Author

CalebDu commented Oct 11, 2024

A and B tensors are both K-major. I found that CuTe has array_subbyte.hpp to support subbyte dtypes, so maybe the wrong result is not caused by subbyte packing/unpacking. I need more attempts to locate the problem.

@ccecka

ccecka commented Oct 11, 2024

You can create a packed CuTe tensor like this:

Tensor mA = make_tensor(make_gmem_ptr<uint1b_t>(my_ptr), my_layout_of_1b);

which creates a "packed" pointer from my_ptr. Similarly with rmem or smem:

Tensor sA = make_tensor(make_smem_ptr<uint1b_t>(my_s_ptr), my_layout_of_1b);
Tensor rA = make_tensor<uint1b_t>(my_layout_of_1b);

It looks like you're creating it with a raw uint1b_t*? CuTe cannot safely assume that such a pointer means "packed" (and neither should CUTLASS...). This is the reason why array_subbyte.data() has been removed from CuTe's array_subbyte container (but apparently not from CUTLASS's) -- it is dangerous and error-prone.

@CalebDu
Copy link
Contributor Author

CalebDu commented Oct 12, 2024

@ccecka you're right. The root cause is how make_gmem_ptr is called.
[Image: old vs. new CuTe tensor initialization]
T is uint1b_t and ACC_T is int32_t. With the old version of initializing the CuTe tensor in the GEMM kernel, print_tensor shows 255 (0b11111111) per uint1b_t element, with a wrong GEMM result.
Following your example (the new version), print_tensor shows 1 (0x1) per uint1b_t element, with the correct GEMM result.
[Images: print_tensor output for the old and new versions]

The new version calls recast_ptr correctly.
The CuTe documentation should mention this difference so that other users can avoid the same problem.
@ccecka, thanks for your help!

@CalebDu CalebDu closed this as completed Oct 12, 2024