[FEA]question about Cute supports uint1b_t GeMM #1863
Comments
Not having sub-byte pack/unpack should not lead to the GEMM getting incorrect results if you are using So are you trying to accept MN-major input tensors, or is it only the printing that is broken?
Both the A and B tensors are K-major. I find CuTe has
You can create a packed CuTe tensor like this:
Tensor mA = make_tensor(make_gmem_ptr<uint1b_t>(my_ptr), my_layout_of_1b);
which creates a "packed" pointer from
Tensor sA = make_tensor(make_smem_ptr<uint1b_t>(my_s_ptr), my_layout_of_1b);
Tensor rA = make_tensor<uint1b_t>(my_layout_of_1b);
It looks like you're creating it with a raw
@ccecka you're right. The root cause is how to call
I want to develop a customized fused GEMM kernel for the uint1b_t dtype in CuTe. I note that CuTe has no official example or test for uint1b_t GEMM, and that MMA_Traits only instantiates
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
with the wrong ABLayout. So I created PR #1856 to fix the wrong ABLayout and to add more MMA_Traits and MMA op support for uint1b_t. After that, I started to implement a simple GEMM for uint1b_t. But I find that CuTe doesn't seem to support sub-byte pack and unpack like SubbyteReference in CUTLASS, which leads to my customized GEMM getting the wrong result.
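For checking a custom kernel's output, a host-side reference can help. The following is a minimal sketch, not CUTLASS or CuTe code: it assumes, from the XORPOPC suffix of the op name, that each output element accumulates popcount(A xor B) over K, that both operands are K-major and packed 32 bits per word, and that K is a multiple of 32. The names `popcount32` and `xor_popc_gemm_ref` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Portable population count (number of set bits) of one 32-bit word.
inline int popcount32(std::uint32_t x) {
    int count = 0;
    while (x != 0) {
        x &= x - 1;  // clear the lowest set bit
        ++count;
    }
    return count;
}

// Hypothetical host reference for a 1-bit XOR-POPC GEMM.
// A is M x K, B is N x K, both K-major and packed 32 bits per word.
// Returns C as a row-major M x N matrix with
//   C[m][n] = sum over k of popcount(A[m][k] xor B[n][k]).
inline std::vector<std::int32_t> xor_popc_gemm_ref(
        const std::vector<std::uint32_t>& A,
        const std::vector<std::uint32_t>& B,
        std::size_t M, std::size_t N, std::size_t K) {
    assert(K % 32 == 0);  // assumption: no partial words along K
    const std::size_t words = K / 32;
    assert(A.size() == M * words);
    assert(B.size() == N * words);

    std::vector<std::int32_t> C(M * N, 0);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            std::int32_t acc = 0;
            for (std::size_t w = 0; w < words; ++w) {
                acc += popcount32(A[m * words + w] ^ B[n * words + w]);
            }
            C[m * N + n] = acc;
        }
    }
    return C;
}
```

Because the accumulation is a plain sum of popcounts, the bit order inside a word does not affect the result as long as A and B are packed the same way.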
To prove it, I ran a simple test with the code above. I create a CUTLASS tensor A with shape (1024, 128) in the uint1b_t dtype and initialize it with 0x1. CUTLASS packs 8 uint1b_t values into 1 uint8_t, so A is allocated 1024*128/8 bytes = 16 KiB of memory. Then I use the A.device_data() pointer to build a CuTe tensor and print both A tensors as follows. The CUTLASS A tensor prints the correct result, but the CuTe A tensor prints 255 (0b11111111) because the uint1b_t values are stored in uint8_t and read back without bit pack and unpack.
I want to know how CuTe can support uint1b_t GEMM with pack/unpack to get the correct result. I can help submit a PR.
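The arithmetic in this test can be reproduced on the host. The sketch below is illustrative only and does not use CUTLASS or CuTe: the helper names `packed_bytes`, `load_bit`, `store_bit`, and `make_filled` are hypothetical, and the bit order (element 0 in the least significant bit of byte 0) is an assumption. It shows why reading the packed storage byte by byte yields 255 while unpacking each element yields 1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Bytes needed to store `num_bits` 1-bit elements, 8 elements per byte.
inline std::size_t packed_bytes(std::size_t num_bits) {
    return (num_bits + 7) / 8;
}

// Unpack: read logical 1-bit element `i` from packed storage.
// Assumed bit order: element 0 is the least significant bit of byte 0.
inline std::uint8_t load_bit(const std::vector<std::uint8_t>& storage,
                             std::size_t i) {
    assert(i / 8 < storage.size());
    return static_cast<std::uint8_t>((storage[i / 8] >> (i % 8)) & 0x1u);
}

// Pack: write logical 1-bit element `i` without touching its neighbours.
inline void store_bit(std::vector<std::uint8_t>& storage, std::size_t i,
                      std::uint8_t value) {
    assert(i / 8 < storage.size());
    const std::uint8_t mask = static_cast<std::uint8_t>(1u << (i % 8));
    if (value & 0x1u) {
        storage[i / 8] = static_cast<std::uint8_t>(storage[i / 8] | mask);
    } else {
        storage[i / 8] = static_cast<std::uint8_t>(storage[i / 8] & ~mask);
    }
}

// Allocate packed storage for a rows x cols 1-bit tensor and set every
// element to `value` (0 or 1).
inline std::vector<std::uint8_t> make_filled(std::size_t rows,
                                             std::size_t cols,
                                             std::uint8_t value) {
    std::vector<std::uint8_t> storage(packed_bytes(rows * cols), 0);
    for (std::size_t i = 0; i < rows * cols; ++i) {
        store_bit(storage, i, value);
    }
    return storage;
}
```

With `make_filled(1024, 128, 1)`, the storage holds 16384 bytes, each raw byte equals 255, and `load_bit` returns 1 for every element, which matches the difference between the two printouts described above.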