[QST] Performance Issue of doing GEMM on A100 using CuTe #1858

Open
Yanksi opened this issue Oct 10, 2024 · 19 comments


Yanksi commented Oct 10, 2024

Hi, I've just created a small project (link to the project) by modifying the sgemm_sm80 example. The goal was to make use of the tensor cores for the computation. Unfortunately, when testing on an A100, the performance never seems to reach the peak. Below are the best results I got from the autotuning process; the best performance achieved is always a bit less than half of the theoretical peak the A100 provides. Any comments on how I can improve this?


Results for float
                         name       TN  NT
523  gemm_float_config1672_TN  63901.4 NaN
593  gemm_float_config2318_TN  62808.7 NaN
769  gemm_float_config1708_TN  61919.7 NaN
680  gemm_float_config2354_TN  61474.6 NaN
895  gemm_float_config1702_TN  60324.8 NaN
                          name  TN       NT
1871  gemm_float_config1449_NT NaN  62819.8
2047  gemm_float_config2390_NT NaN  60782.0
1857  gemm_float_config1412_NT NaN  60040.9
1902  gemm_float_config1448_NT NaN  58810.4
2309  gemm_float_config1376_NT NaN  58670.4



Results for half
                         name        TN  NT
402   gemm_half_config2750_TN  155641.2 NaN
982   gemm_half_config2752_TN  154693.0 NaN
749   gemm_half_config1804_TN  147818.7 NaN
194   gemm_half_config2456_TN  146790.2 NaN
1054  gemm_half_config1808_TN  145774.8 NaN
                         name  TN       NT
1844  gemm_half_config1840_NT NaN  92024.8
2407  gemm_half_config2492_NT NaN  91585.4
1225  gemm_half_config1844_NT NaN  91393.5
1885  gemm_half_config1808_NT NaN  90688.7
1461  gemm_half_config1514_NT NaN  90666.0

ccecka commented Oct 11, 2024

Your rewrite of sgemm_sm80 with the extra tiling of mma_k and extra loop is unnecessary and could inhibit the compiler from properly pipelining smem and rmem and mma. I don't recommend that.

Beyond that, it appears that you're only using row-major/col-major smem, SIMT smem->rmem TiledCopys, and trivial TiledMMAs. I agree that we should include a non-trivial SM80 tensor-core example in the tutorial...
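For reference, the difference in the TiledMMA declaration is roughly this (a sketch only; the stock example uses float for TA/TB/TC, and the tensor-core variant is the configuration from the diff below):

// Trivial SIMT TiledMMA as in the stock sgemm_sm80 example: one scalar FMA per thread.
TiledMMA mma_simt = make_tiled_mma(UniversalFMA<float,float,float>{},
                                   Layout<Shape<_16,_16,_1>>{});      // 16x16x1 threads

// Tensor-core TiledMMA (the half_t configuration used in the diff below).
TiledMMA mma_tc   = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
                                   Layout<Shape<_2,_4>>{},            // 2x4x1 MMA Atoms
                                   Tile<_32,_64,_16>{});              // 32x64x16 Tiled MMA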

I did need to modify the sgemm_sm80 kernel slightly to extend smem and got this performance

** New HALF_T for all TA, TB, TC, TI **
$ ./sgemm_sm80 5120 5120 4096 T N
Using device 0: NVIDIA A100-PCIE-40GB (SM80, 108)
M = 5120
N = 5120
K = 4096
C = A^T B^N
CUTE_GEMM:     [126161.8]GFlop/s  (1.7022)ms

The next step would be to use LDSM for the smem->rmem load (this smem layout is already designed for the LDSM pattern... and is what is slowing it down now), which I've left notes on and we could look at CUTLASS's SM80 Collective to see how that's done. That should achieve speed-of-light, peak performance on A100.
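For concreteness, the commented-out lines in the diff below already sketch that step. Roughly (an untested sketch, reusing the mmaC from the diff), the smem->rmem TiledCopys would be built from LDSM atoms and the TiledMMA and used in the mainloop in place of the default SIMT smem->rmem loads:

// Sketch only -- mirrors the commented-out alternatives in the diff below.
Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_A;
TiledCopy s2r_copy_A = make_tiled_copy_A(s2r_atom_A, mmaC);   // partitions smem like the MMA's A operand

Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_B;
TiledCopy s2r_copy_B = make_tiled_copy_B(s2r_atom_B, mmaC);   // partitions smem like the MMA's B operand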

Here's my diff/patch/half_t configuration

diff --git a/cutlass/examples/cute/tutorial/sgemm_sm80.cu b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
index d042933040..69b8457c1c 100644
--- a/cutlass/examples/cute/tutorial/sgemm_sm80.cu
+++ b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
@@ -46,6 +46,16 @@
 #include "cutlass/util/helper_cuda.hpp"
 
+template <class ElementA,
+          class ElementB,
+          class SmemLayoutA,
+          class SmemLayoutB>
+struct SharedStorage
+{
+  cute::array<ElementA, cute::cosize_v<SmemLayoutA>> A;
+  cute::array<ElementB, cute::cosize_v<SmemLayoutB>> B;
+};
+
 template <class ProblemShape, class CtaTiler,
           class TA, class AStride, class ASmemLayout, class TiledCopyA,
           class TB, class BStride, class BSmemLayout, class TiledCopyB,
@@ -100,10 +110,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)
 
   // Shared memory buffers
-  __shared__ TA smemA[cosize_v<ASmemLayout>];
-  __shared__ TB smemB[cosize_v<BSmemLayout>];
-  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K,PIPE)
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
+  Tensor sA = make_tensor(make_smem_ptr(smem.A.data()), sA_layout);    // (BLK_M,BLK_K,PIPE)
+  Tensor sB = make_tensor(make_smem_ptr(smem.B.data()), sB_layout);    // (BLK_N,BLK_K,PIPE)
 
   //
   // Partition the copying of A and B tiles across the threads
@@ -301,6 +312,127 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   axpby(alpha, tCrC, beta, tCgC);
 }
 
+template <class Alpha, class Beta>
+void
+gemm_nt(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  assert(false && "Not implemented");
+}
+
+// Setup params for a TN HGEMM
+template <class Alpha, class Beta>
+void
+gemm_tn(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  using namespace cute;
+
+  // Define shapes (dynamic)
+  auto M = int(m);
+  auto N = int(n);
+  auto K = int(k);
+  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)
+
+  // Define TN strides (mixed)
+  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
+  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
+  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)
+
+  // Define CTA tile sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int< 64>{};
+  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
+  auto bP = Int<3>{};  // Pipeline
+
+  // Define the smem layouts (static)
+  // Swizzles for LDSM and 128b k-major loads
+  auto swizzle_atom = composition(Swizzle<3,3,3>{},
+                                  Layout<Shape <_8,Shape <_8, _8>>,
+                                         Stride<_8,Stride<_1,_64>>>{});
+  auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
+  auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
+  auto sC = make_layout(make_shape(bM, bN));
+
+  // Define the thread layouts (static)
+
+  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 k-major
+  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 n-major
+
+  TiledMMA mmaC = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
+                                 Layout<Shape<_2,_4>>{},    // 2x4x1 MMA Atoms
+                                 Tile<_32,_64,_16>{});      // 32x64x16 Tiled MMA for LDSM
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_A;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_A;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_A;
+  // TiledCopy copyA = make_tiled_copy_A(copy_atom_A, mmaC);
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_B;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_B;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_B;
+  // TiledCopy copyB = make_tiled_copy_B(copy_atom_B, mmaC);
+
+#if 0
+  print(copyA);
+  print(copyB);
+  print(mmaC);
+#endif
+
+#if 0
+  print_latex(copyA);
+  print_latex(copyB);
+  print_latex(mmaC);
+#endif
+
+  int smem_size = int(sizeof(SharedStorage<cute::half_t, cute::half_t, decltype(sA), decltype(sB)>));
+  dim3 dimBlock(size(mmaC));
+  dim3 dimGrid(size(ceil_div(M, bM)),
+               size(ceil_div(N, bN)));
+
+  auto kernel_fptr = gemm_device<
+    decltype(prob_shape), decltype(cta_tiler),
+    cute::half_t, decltype(dA), decltype(sA), decltype(copyA),
+    cute::half_t, decltype(dB), decltype(sB), decltype(copyB),
+    cute::half_t, decltype(dC), decltype(sC), decltype(mmaC),
+    decltype(alpha), decltype(beta)>;
+
+  // Set L1 to be SMEM only
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+
+  kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>
+      (prob_shape, cta_tiler,
+       A, dA, sA, copyA,
+       B, dB, sB, copyB,
+       C, dC, sC, mmaC,
+       alpha, beta);
+}
+
 // Setup params for a NT GEMM
 template <class TA, class TB, class TC,
           class Alpha, class Beta>
@@ -362,10 +494,11 @@ gemm_nt(int m, int n, int k,
   print_latex(mmaC);
 #endif
 
+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
                size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
       (prob_shape, cta_tiler,
        A, dA, sA, copyA,
        B, dB, sB, copyB,
@@ -438,10 +571,11 @@ gemm_tn(int m, int n, int k,
   print_latex(mmaC);
 #endif
 
+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
                size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
       (prob_shape, cta_tiler,
        A, dA, sA, copyA,
        B, dB, sB, copyB,
@@ -485,6 +619,11 @@ int main(int argc, char** argv)
     return 0;
   }
 
+  std::cout << "Using device 0: " << props.name
+            << " (SM" << props.major * 10 + props.minor
+            << ", " << props.multiProcessorCount
+            << ")" << std::endl;
+
   int m = 5120;
   if (argc >= 2)
     sscanf(argv[1], "%d", &m);
@@ -505,13 +644,13 @@ int main(int argc, char** argv)
   if (argc >= 6)
     sscanf(argv[5], "%c", &transB);
 
-  using TA = float;
-  using TB = float;
-  using TC = float;
-  using TI = float;
+  using TA = cute::half_t;
+  using TB = cute::half_t;
+  using TC = cute::half_t;
+  using TI = cute::half_t;
 
-  TI alpha = 1.0;
-  TI beta  = 0.0;
+  TI alpha = static_cast<TI>(1.0f);
+  TI beta  = static_cast<TI>(0.0f);
 
   std::cout << "M = " << m << std::endl;
   std::cout << "N = " << n << std::endl;


Yanksi commented Oct 15, 2024

I have updated my project over the past few days and included a README.md for better documentation (hopefully).

Your rewrite of sgemm_sm80 with the extra tiling of mma_k and extra loop is unnecessary and could inhibit the compiler from properly pipelining smem and rmem and mma. I don't recommend that.

The extra tiling of mma_k was meant to match the smem->rmem pipeline shown in the sgemm_sm80 example. I'm not sure whether it prevents the compiler from doing that optimization, but in my experience the performance did not drop after it was implemented.

The best performance I got on the A100 with the FP16 datatype after autotuning is 155641.2 GFlop/s. By comparison, the reference cuBLAS GEMM gave me about 22 TFlops (sorry, I cannot recall the exact number, and I have not had access to an A100 recently).

I am running out of ideas for further optimization at this point. If anyone could take a look at my code and suggest what the next optimization might be, I would greatly appreciate it!


ccecka commented Oct 16, 2024

I recall something about SM80_16x8x16_F16F16F16F16_TN being significantly more difficult to optimize than SM80_16x8x16_F32F16F16F32_TN, though I forget the details as it's been a while since deep work on Ampere.

Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for half_t. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accept very similar configuration parameters.

sgemm_sm80_tmp.txt


Yanksi commented Oct 16, 2024

Thanks a lot for your quick response!!

I noticed that a class called Swizzle is used in your implementation. I think it is meant to avoid bank conflicts in the smem->rmem copies, but I am not able to find documentation for that class. Could you please explain a bit what exactly

auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                  Layout<Shape <_8,Shape <_8, _8>>,
                                         Stride<_8,Stride<_1,_64>>>{});

doing?

About those LDSM copy atoms: I think their naming convention is SM75_<dtype>x<#items>_LDSM_<layout>? If I understand that correctly, why did you use <dtype>=U32 in your code when we are dealing with half? Also, you used the same layout for both A and B in gemm_nt, which from my perspective should be different to make more sense? Another question regarding the LDSM atoms: what exactly is the difference between the atoms with different <#items>, and what would the underlying copy layout be in each case?

I found the definitions of those layouts in copy_traits_sm75.hpp, but I am confused about the meaning of the comment // Map from (src-thr,src-val) to bit.


ccecka commented Oct 16, 2024

print_latex will answer all of those questions. Call it on Layouts, TiledCopy, TiledMMA, etc.
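For example, a minimal host-side sketch (assuming the CuTe headers are on the include path; the TiledMMA/Copy_Atom configuration below is the half_t one from the patch above) that dumps the LaTeX to stdout, which you can then compile with pdflatex:

#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // Rebuild the objects of interest on the host and print their LaTeX pictures.
  TiledMMA mmaC = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
                                 Layout<Shape<_2,_4>>{},    // 2x4x1 MMA Atoms
                                 Tile<_32,_64,_16>{});      // 32x64x16 Tiled MMA
  Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_A;
  TiledCopy s2r_copy_A = make_tiled_copy_A(s2r_atom_A, mmaC);

  print_latex(mmaC);        // thread/value assignment of the TiledMMA for A, B, and C
  print_latex(s2r_copy_A);  // source (smem) and destination (rmem) access patterns of the copy
}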

@ghostplant

I recall something about SM80_16x8x16_F16F16F16F16_TN being significantly more difficult to optimize than SM80_16x8x16_F32F16F16F32_TN, though I forget the details as it's been a while since deep work on Ampere.

Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for half_t. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accept very similar configuration parameters.

sgemm_sm80_tmp.txt

Do you think it can be improved further? The example you provided currently reaches 150 TFlops on A100, while cuBLAS gets 300 TFlops.


Yanksi commented Oct 23, 2024

I found this high-performance implementation of GEMM using CuTe. The author wrote a series of tutorials for CuTe (in Chinese) on Zhihu, and in one of them claims that this implementation reaches cuBLAS-level performance on an RTX 3090. I am not sure whether the same holds on an A100, as I currently don't have access to one. I think this tutorial series is a very good complement to the official CuTe documentation.

@ghostplant

I found this high-performance implementation of GEMM using CuTe. The author wrote a series of tutorials for CuTe (in Chinese) on Zhihu, and in one of them claims that this implementation reaches cuBLAS-level performance on an RTX 3090. I am not sure whether the same holds on an A100, as I currently don't have access to one. I think this tutorial series is a very good complement to the official CuTe documentation.

I tried that before; it looks like there are compilation issues because its interface no longer matches the latest CUTLASS CuTe.


Yanksi commented Oct 28, 2024

I found this high-performance implementation of GEMM using CuTe. The author wrote a series of tutorials for CuTe (in Chinese) on Zhihu, and in one of them claims that this implementation reaches cuBLAS-level performance on an RTX 3090. I am not sure whether the same holds on an A100, as I currently don't have access to one. I think this tutorial series is a very good complement to the official CuTe documentation.

I tried that before; it looks like there are compilation issues because its interface no longer matches the latest CUTLASS CuTe.

@ccecka's code actually works in exactly the same way as that code base. I believe all the optimizations from that code base have been adopted in @ccecka's code, except the epilogue that stages the computation result back to global memory.


Yanksi commented Oct 28, 2024

I recall something about SM80_16x8x16_F16F16F16F16_TN being significantly more difficult to optimize than SM80_16x8x16_F32F16F16F32_TN, though I forget the details as it's been a while since deep work on Ampere.

Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for half_t. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accept very similar configuration parameters.

sgemm_sm80_tmp.txt

I am currently having trouble understanding the LDSM part of the code. @ccecka

To understand the code, I wrote the following example to play with the LDSM copies.

#include <cute/tensor.hpp>
#include <iostream>

int main() {
    using namespace cute;
    half_t* A = new half_t[256*32];
    Copy_Atom<SM75_U32x2_LDSM_N, half_t> s2r_atom_a;
    TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                 Layout<Shape<_2,_2>>{},    // 2x2x1 MMA Atoms
                                 Tile<_32,_32,_16>{});      // 32x32x16 Tiled MMA for LDSM
    // print_latex(mmaC);
    Tensor sA = make_tensor(A, make_layout(make_shape(_64{}, _32{}, _2{})));
    Tensor sB = make_tensor(A, make_layout(make_shape(_64{}, _32{}, _2{})));
    ThrMMA thr_mma = mmaC.get_slice(0);
    Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0));        // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0));        // (MMA,MMA_N,MMA_K)

    TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mmaC);
    ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(0);
    Tensor tXsA = s2r_thr_copy_a.partition_S(sA);                 // (CPY,MMA_M,MMA_K,PIPE)
    Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA);                  // (CPY,MMA_M,MMA_K)
    printf("tCrA: "); print(tCrA); printf("\n");
    printf("tXsA: "); print(tXsA); printf("\n");
    printf("tXrA: "); print(tXrA); printf("\n");

    printf("\n");

    TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_a, mmaC);
    ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(0);
    Tensor tXsB = s2r_thr_copy_b.partition_S(sB);                 // (CPY,MMA_N,MMA_K,PIPE)
    Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB);                  // (CPY,MMA_N,MMA_K)
    
    printf("tCrB: "); print(tCrB); printf("\n");
    printf("tXsB: "); print(tXsB); printf("\n");
    printf("tXrB: "); print(tXrB); printf("\n");
}

The mma atom I used there is 16x8x8, which is different from the one in your code. In this case, no matter which LDSM atom of the form SM75_U32x<i>_LDSM_N I choose for s2r_atom_a, I always get an incompatible MMA_K extent between tXrA and tCrA. The same happens between tXrB and tCrB. This incompatibility breaks the smem->rmem pipeline (see the fragment below). Could you please help me resolve this without changing the mma atom? I need to use this atom because I want to deal with bmm with small Ks.
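For context, this is the sgemm_sm80-style per-k-block staging I am trying to reproduce (a fragment only, with names from the snippet above; smem_pipe_read, tCrC, and the surrounding kernel are placeholders):

// Fragment of the intended mainloop, not a full kernel.
// The copy/mma interleaving below needs size<2>(tXrA) == size<2>(tCrA) (same MMA_K),
// which is exactly the compatibility that is failing here.
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
  int k_next = (k_block + 1) % size<2>(tCrA);                        // prefetch next k-block
  copy(s2r_copy_a, tXsA(_,_,k_next,smem_pipe_read), tXrA(_,_,k_next));
  copy(s2r_copy_b, tXsB(_,_,k_next,smem_pipe_read), tXrB(_,_,k_next));
  gemm(mmaC, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);            // overlaps with the copies above
}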


ccecka commented Oct 29, 2024

The incompatibility is coming from the layout of smem.

You can view the access patterns for the source and destination tensors of TiledCopy:

print_latex(s2r_copy_a);

which yields
[Image: print_latex rendering of s2r_copy_a]
You can see that, in the source tensor, each thread wants to access 16 elements across each row. So the SMEM should be at least row-major. I suggest starting with

    Tensor sA = make_tensor(A, Layout<Shape <_64,_32,   _2>,
                                      Stride<_32, _1,_2048>>{});


Yanksi commented Oct 31, 2024

The incompatibility is coming from the layout of smem.

You can view the access patterns for the source and destination tensors of TiledCopy:

print_latex(s2r_copy_a);
which yields [image]. You can see that, in the source tensor, each thread wants to access 16 elements across each row. So the SMEM should be at least row-major. I suggest starting with

Tensor sA = make_tensor(A, Layout<Shape <_64,_32,   _2>,
                                  Stride<_32, _1,_2048>>{});

The incompatibility is still there even after I initialized sA as suggested.

Here's the modified version of the program

#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    half_t* A = new half_t[256*32];
    Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_a;
    TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                 Layout<Shape<_2,_2>>{},    // 2x2x1 MMA Atoms
                                 Tile<_32,_32,_16>{});      // 32x32x16 Tiled MMA for LDSM

    Tensor sA = make_tensor(A, Layout<Shape <_64,_32,   _2>,
                                      Stride<_32, _1,_2048>>{});
    Tensor sB = make_tensor(A, Layout<Shape <_64,_32,   _2>,
                                      Stride<_32, _1,_2048>>{});
    
    ThrMMA thr_mma = mmaC.get_slice(0);
    Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0));        // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0));        // (MMA,MMA_N,MMA_K)
    Tensor tCsA = thr_mma.partition_A(sA(_,_,0));                 // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB(_,_,0));                 // (MMA,MMA_N,MMA_K)

    TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mmaC);
    ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(0);
    Tensor tXsA = s2r_thr_copy_a.partition_S(sA);                 // (CPY,MMA_M,MMA_K,PIPE)
    Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA);
    print("tCsA: "); print(tCsA); print("\n");                  // (CPY,MMA_M,MMA_K)
    print("tCrA: "); print(tCrA); print("\n");
    print("tXsA: "); print(tXsA); print("\n");
    print("tXrA: "); print(tXrA); print("\n");
    print("\n");

    TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_a, mmaC);
    ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(0);
    Tensor tXsB = s2r_thr_copy_b.partition_S(sB);                 // (CPY,MMA_N,MMA_K,PIPE)
    Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB);                  // (CPY,MMA_N,MMA_K)
    print("tCsB: "); print(tCsB); print("\n"); 
    print("tCrB: "); print(tCrB); print("\n");
    print("tXsB: "); print(tXsB); print("\n");
    print("tXrB: "); print(tXrB); print("\n");
}

And here's the output

tCsA: ptr[16b](0x1bd41d0) o ((_2,_2),_2,_4):((_1,_256),_1024,_8)
tCrA: ptr[16b](0x7ffe5f2a4b50) o ((_2,_2),_2,_4):((_1,_2),_16,_4)
tXsA: ptr[16b](0x1bd41d0) o ((_8,(_2,_2)),_2,_2,_2):((_1,(_256,_8)),_1024,_16,_2048)
tXrA: ptr[16b](0x7ffe5f2a4b50) o ((_2,_4),_2,_2):((_1,_2),_16,_8)

tCsB: ptr[16b](0x1bd41d0) o (_2,_4,_4):(_1,_512,_8)
tCrB: ptr[16b](0x7ffe5f2a4b10) o (_2,_4,_4):(_1,_8,_2)
tXsB: ptr[16b](0x1bd41d0) o ((_8,(_2,_2)),_2,_2,_2):((_1,(_512,_8)),_1024,_16,_2048)
tXrB: ptr[16b](0x7ffe5f2a4b10) o ((_2,(_2,_2)),_2,_2):((_1,(_8,_2)),_16,_4)


ccecka commented Oct 31, 2024

I see, you're concerned about the MMA_K mode between tCsA and tXsA. Duplicating an MMA in the K-mode requires some extra care, yes. I recommend avoiding that:

    TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                   Layout<Shape<_2,_2>>{},    // 2x2x1 MMA Atoms
                                   Tile<_32,_32>{});          // 32x32x8 Tiled MMA for LDSM

You should be able to expand the M- and N-size of the TiledMMA to accommodate the LDSMs. This will give you a finer granularity for interleaving in the mainloop as well. Sorry for the confusion.
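For example (an untested sketch keeping your 16x8x8 atom, with the atom layout and permutation borrowed from the half_t configuration above), the footprint can be grown through the atom layout and the permutation Tile in M and N rather than through K:

    TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                   Layout<Shape<_2,_4>>{},   // 2x4x1 MMA Atoms -> 32x32x8
                                   Tile<_32,_64>{});         // permute N up to 64 -> 32x64x8 Tiled MMA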


Yanksi commented Nov 6, 2024

I see, you're concerned about the MMA_K mode between tCsA and tXsA. Duplicating an MMA in the K-mode requires some extra care, yes. I recommend avoiding that:

TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                               Layout<Shape<_2,_2>>{},    // 2x2x1 MMA Atoms
                               Tile<_32,_32>{});          // 32x32x8 Tiled MMA for LDSM

You should be able to expand the M- and N-size of the TiledMMA to accommodate the LDSMs. This will give you a finer granularity for interleaving in the mainloop as well. Sorry for the confusion.

Thanks! That solved the problem.

I found that CuTe does not stop me from creating a TiledMMA whose Tile size is smaller than what it is supposed to be. For example, for the following code

#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    half_t* A = new half_t[256*32*3];
    Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_a;
    Copy_Atom<AutoVectorizingCopy, half_t> s2r_atom_b;
    auto bM = _256{};
    auto bN = _256{};
    auto bK = _16{};

    TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                Layout<Shape<_4,_2>>{},
                                Tile<_32, _32>{});      // 32x32 permutation Tile (smaller than the 4x2 atoms' 64x16 footprint)
    auto tile_size_m = mmaC.tile_size_mnk<0>();
    auto tile_size_n = mmaC.tile_size_mnk<1>();
    auto tile_size_k = mmaC.tile_size_mnk<2>();

    
    Tensor sA = make_tensor(A, Layout<Shape <_64,_32,   _2>,
                                      Stride<_32, _1,_2048>>{});

    ThrMMA thr_mma = mmaC.get_slice(0);
    Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0));        // (MMA,MMA_M,MMA_K)
    print(tCrA);print("\n");
}

CuTe seems to be quite happy with it, while the value of tile_size_m here is 64, which is twice the M size defined in the tiling parameter. When I run print_latex on mmaC, it indeed produces a weird layout for the A part of the MMA. In this case, is a static check for the tiling actually missing in CuTe, or is it just designed to be this way?

[Image: print_latex rendering of mmaC]


thakkarV commented Nov 6, 2024

Very much intentional and often used this way. We iterate over the rest modes of the partitioned tensors too.
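For instance, this is what the mainloop of the sgemm examples does with the (MMA,MMA_M,MMA_K)-shaped fragments (a sketch only, names as in those examples):

// The MMA atom consumes mode-0 of the fragments; the rest modes (MMA_M, MMA_N, MMA_K)
// are iterated -- explicitly as below, or internally by cute::gemm when given whole fragments.
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
  gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);   // this call also loops over MMA_M and MMA_N
}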


Yanksi commented Nov 6, 2024

Very much intentional and often used this way. We iterate over the rest modes of the partitioned tensors too.

In this case, how will the MMA be performed? The layout of tCrA produced by the given code is ((_2,_2),_1,_4):((_1,_2),_0,_4), and it remains the same when I change sA to Tensor sA = make_tensor(A, Layout<Shape <_32, _32, _2>, Stride<_32, _1,_1024>>{});. This really confuses me.



ccecka commented Dec 6, 2024

Correct, this is a bug; there should be additional static assertions on construction to prevent incompatible parameters (or an override in the case of unnecessary identity permutations here), as you mention.
