Skip to content

Commit

Permalink
Fix local_tile for version 3.3 and newer, NVIDIA/cutlass#1201
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 12, 2023
1 parent f804b4c commit a26f6ad
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions gemm/cuda_cute/04_smem/matmul_smem_0.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_kernel_smem_0) {
store_smem_load_global_b<NumThreads, SmemShapeN, SmemShapeK>(sB, stripe_gB(_, _, _0{}, block_p), stripe_cB(_, _, _0{}, block_p), n, k);
__syncthreads();

const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));

#pragma unroll
for (int smem_AB_thread_p = 0; smem_AB_thread_p < SmemShapeK; smem_AB_thread_p++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ __launch_bounds__(NumThreads, 2) MATMUL_KERNEL_SIGNATURE(matmul_smem_and_registe
load_global_a<NumThreads, SmemShapeM, SmemShapeK, SmemALoadStoreVec>(staging_a, stripe_gA(_, _, _0{}, 0), stripe_cA(_, _, _0{}, 0), m, k);
load_global_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(staging_b, stripe_gB(_, _, _0{}, 0), stripe_cB(_, _, _0{}, 0), n, k);

const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));

const auto num_smem_block = size<3>(stripe_gA);
#pragma unroll 1 // no unroll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_1) {
store_smem_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(sB(_, _, _0{}), staging_b, n, k);
__syncthreads();

const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));

copy(stripe_sA(_, 0, _0{}, _0{}), fragA[0]); // load_fragment a
copy(stripe_sB(_, 0, _0{}, _0{}), fragB[0]); // load_fragment b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ MATMUL_KERNEL_SIGNATURE(matmul_smem_and_register_pipelining_2) {
store_smem_b<NumThreads, SmemShapeN, SmemShapeK, SmemBLoadStoreVec>(sB, staging_b, n, k);
__syncthreads();

const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), threadIdx.x % (CtaShapeM / ThreadShapeM));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), threadIdx.x / (CtaShapeM / ThreadShapeM));
const auto stripe_sA = local_tile(sA, make_tile(Int<ThreadShapeM>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x % (CtaShapeM / ThreadShapeM)));
const auto stripe_sB = local_tile(sB, make_tile(Int<ThreadShapeN>{}, Int<SmemShapeK>{}), make_coord(threadIdx.x / (CtaShapeM / ThreadShapeM)));

copy(stripe_sA(_, 0, _0{}), fragA[0]); // load_fragment a
copy(stripe_sB(_, 0, _0{}), fragB[0]); // load_fragment b
Expand Down

0 comments on commit a26f6ad

Please sign in to comment.