diff --git a/gemm/cuda_cute/04_smem/matmul_smem_0.cu b/gemm/cuda_cute/04_smem/matmul_smem_0.cu index a7e6db2..5d173e4 100644 --- a/gemm/cuda_cute/04_smem/matmul_smem_0.cu +++ b/gemm/cuda_cute/04_smem/matmul_smem_0.cu @@ -133,8 +133,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_kernel_smem_0) { store_smem_load_global_b(sB, stripe_gB(_, _, _0{}, block_p), stripe_cB(_, _, _0{}, block_p), n, k); __syncthreads(); - const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), threadIdx.x % (CtaShapeM / ThreadShapeM)); - const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), threadIdx.x / (CtaShapeM / ThreadShapeM)); + const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM))); + const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM))); #pragma unroll for (int smem_AB_thread_p = 0; smem_AB_thread_p < SmemShapeK; smem_AB_thread_p++) { diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu index 13b4fb5..4436bcd 100644 --- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu +++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_0_naive.cu @@ -166,8 +166,8 @@ __launch_bounds__(NumThreads, 2) MATMUL_KERNEL_SIGNATURE(matmul_smem_and_registe load_global_a(staging_a, stripe_gA(_, _, _0{}, 0), stripe_cA(_, _, _0{}, 0), m, k); load_global_b(staging_b, stripe_gB(_, _, _0{}, 0), stripe_cB(_, _, _0{}, 0), n, k); - const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), threadIdx.x % (CtaShapeM / ThreadShapeM)); - const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), threadIdx.x / (CtaShapeM / ThreadShapeM)); + const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM))); + const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM))); const auto num_smem_block = size<3>(stripe_gA); #pragma unroll 1 // no unroll diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu index 00af8f1..ce440c3 100644 --- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu +++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_1_cutlass.cu @@ -173,8 +173,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_1) { store_smem_b(sB(_, _, _0{}), staging_b, n, k); __syncthreads(); - const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), threadIdx.x % (CtaShapeM / ThreadShapeM)); - const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), threadIdx.x / (CtaShapeM / ThreadShapeM)); + const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM))); + const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM))); copy(stripe_sA(_, 0, _0{}, _0{}), fragA[0]); // load_fragment a copy(stripe_sB(_, 0, _0{}, _0{}), fragB[0]); // load_fragment b diff --git a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu index e1950b9..6b26f4b 100644 --- a/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu +++ b/gemm/cuda_cute/05_pipelining/matmul_smem_and_register_pipelining_2_single_buffer.cu @@ -171,8 +171,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_2) { store_smem_b(sB, staging_b, n, k); __syncthreads(); - const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), threadIdx.x % (CtaShapeM / ThreadShapeM)); - const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), threadIdx.x / (CtaShapeM / ThreadShapeM)); + const auto stripe_sA = local_tile(sA, make_tile(Int{}, Int{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM))); + const auto stripe_sB = local_tile(sB, make_tile(Int{}, Int{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM))); copy(stripe_sA(_, 0, _0{}), fragA[0]); // load_fragment a copy(stripe_sB(_, 0, _0{}), fragB[0]); // load_fragment b