From a26f6adb83213d932d896827d8f43334e14bb57e Mon Sep 17 00:00:00 2001
From: Cloud Han <cloudhan@outlook.com>
Date: Tue, 28 Nov 2023 20:50:40 +0800
Subject: [PATCH] Fix local_tile calls for CUTLASS 3.3 and newer,
 https://github.com/NVIDIA/cutlass/issues/1201

---
 gemm/cuda_cute/04_smem/matmul_smem_0.cu                       | 4 ++--
 .../matmul_smem_and_register_pipelining_0_naive.cu            | 4 ++--
 .../matmul_smem_and_register_pipelining_1_cutlass.cu          | 4 ++--
 .../matmul_smem_and_register_pipelining_2_single_buffer.cu    | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/gemm/cuda_cute/04_smem/matmul_smem_0.cu b/gemm/cuda_cute/04_smem/matmul_smem_0.cu
index a7e6db2..5d173e4 100644
--- a/gemm/cuda_cute/04_smem/matmul_smem_0.cu
+++ b/gemm/cuda_cute/04_smem/matmul_smem_0.cu
@@ -133,8 +133,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_kernel_smem_0) {
     store_smem_load_global_b<NumThreads, SmemShapeN, SmemShapeK>(sB, stripe_gB(_, _, _0{}, block_p), stripe_cB(_, _, _0{}, block_p), n, k);
     __syncthreads();
 
-    const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
-    const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
+    const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
+    const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));
 
 #pragma unroll
     for (int smem_AB_thread_p = 0; smem_AB_thread_p < SmemShapeK; smem_AB_thread_p++) {
diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu
index 13b4fb5..4436bcd 100644
--- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu
+++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu
@@ -166,8 +166,8 @@ __launch_bounds__(NumThreads, 2) MATMUL_KERNEL_SIGNATURE(matmul_smem_and_registe
   load_global_a<NumThreads, SmemShapeM, SmemShapeK, SmemALoadStoreVec>(staging_a, stripe_gA(_, _, _0{}, 0), stripe_cA(_, _, _0{}, 0), m, k);
   load_global_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(staging_b, stripe_gB(_, _, _0{}, 0), stripe_cB(_, _, _0{}, 0), n, k);
 
-  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
-  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
+  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
+  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));
 
   const auto num_smem_block = size<3>(stripe_gA);
 #pragma unroll 1  // no unroll
diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu
index 00af8f1..ce440c3 100644
--- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu
+++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu
@@ -173,8 +173,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_1) {
   store_smem_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(sB(_, _, _0{}), staging_b, n, k);
   __syncthreads();
 
-  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
-  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
+  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
+  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));
 
   copy(stripe_sA(_, 0, _0{}, _0{}), fragA[0]);  // load_fragment a
   copy(stripe_sB(_, 0, _0{}, _0{}), fragB[0]);  // load_fragment b
diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu
index e1950b9..6b26f4b 100644
--- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu
+++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu
@@ -171,8 +171,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_2) {
   store_smem_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(sB, staging_b, n, k);
   __syncthreads();
 
-  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
-  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
+  const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
+  const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));
 
   copy(stripe_sA(_, 0, _0{}), fragA[0]);  // load_fragment a
   copy(stripe_sB(_, 0, _0{}), fragB[0]);  // load_fragment b