5476 cutlass 3x gemm kernels (#1695)
Co-authored-by: dePaul Miller <[email protected]>
depaulmillz and depaulmillz authored Aug 8, 2024
1 parent e22ba59 commit 2049c6c
Showing 1 changed file with 115 additions and 44 deletions.
159 changes: 115 additions & 44 deletions python/cutlass_library/generator.py
@@ -4960,47 +4960,68 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
DataType.bf16, DataType.bf16, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 16],
DataType.f16, DataType.f16, DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 16],
DataType.f16, DataType.f16, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 16],
DataType.bf16, DataType.bf16, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
]

min_cc = 90
max_cc = 90

for math_inst in math_instructions:
tile_descriptions_small = [
# Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_large = [
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
# 128x256x128
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_large
tile_descriptions = []
tile_descriptions_small = []
tile_descriptions_medium = []
tile_descriptions_large = []

if math_inst.instruction_shape[1] == 128:
tile_descriptions_small = [
# Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
]
tile_descriptions_large = [
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_large
else:
tile_descriptions = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
]

data_type = {
"a_type" : math_inst.element_a,
@@ -5043,7 +5064,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
# persistent kernels with TMA epilogues
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
# not enough smem for 256x128 f32 out with C allocation
if data_type["d_type"] == DataType.f32:
if data_type["d_type"] == DataType.f32 and len(tile_descriptions_medium) > 0:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
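
Context for the hunk above (not part of the commit): each TileDescription threadblock shape is derived from the MathInstruction's instruction_shape by scaling M by 1x, 2x, or 4x and K by 4x while leaving N unchanged, and the new 256-wide instructions keep only the 2x-M variant. A minimal illustrative sketch of that scaling, using plain lists instead of the cutlass_library classes; the helper name and the "wide" cutoff are assumptions for illustration only.

```python
# Illustrative sketch (not from the commit): how the generator scales a WGMMA
# instruction shape [m, n, k] into candidate threadblock tile shapes.

def candidate_tile_shapes(instruction_shape, wide_n=256):
    """Return [M, N, K] threadblock shapes derived from one instruction shape."""
    m, n, k = instruction_shape
    # N stays fixed and K is multiplied by 4; M gets 1x/2x/4x variants,
    # except for the 256-wide instructions, which keep only the 2x-M tile.
    m_scales = (2,) if n >= wide_n else (1, 2, 4)
    return [[m * s, n, k * 4] for s in m_scales]

# The new 64x256x16 fp16/bf16 instruction yields a 128x256x64 tile, while the
# existing 64x128x16 instruction yields 64/128/256 x 128 x 64 tiles.
print(candidate_tile_shapes([64, 256, 16]))  # [[128, 256, 64]]
print(candidate_tile_shapes([64, 128, 16]))  # [[64, 128, 64], [128, 128, 64], [256, 128, 64]]
```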
@@ -5490,20 +5511,30 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
DataType.u8, DataType.u8, DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 32],
DataType.s8, DataType.s8, DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 32],
DataType.u8, DataType.u8, DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
]

min_cc = 90
max_cc = 90

for math_inst in math_instructions:
# 64x128x128
# 64x128x128 or 64x256x128
tile_descriptions_small = [
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
]
# 128x128x128
# 128x128x128 or 128x256x128
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
@@ -5670,6 +5701,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
DataType.e5m2, DataType.e5m2, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
# inst 64x256x32
MathInstruction(
[64, 256, 32],
DataType.e4m3, DataType.e4m3, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 32],
DataType.e4m3, DataType.e5m2, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 32],
DataType.e5m2, DataType.e4m3, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[64, 256, 32],
DataType.e5m2, DataType.e5m2, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
]

min_cc = 90
@@ -5788,9 +5840,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
# 128x256x128
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = [
# 128x128x128
@@ -5801,6 +5850,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
elif math_inst.instruction_shape[1] == 256:
tile_descriptions_small = [
# 64x256x128
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_large = []
tile_descriptions = [
# 128x256x128
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]


else:
assert False, "math inst is not supported"
@@ -5842,9 +5912,10 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]])

# Large tiles
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile,
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
if len(tile_descriptions_large) > 0:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile,
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])

# Add stream-K variants (with and without TMA epilogues)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
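
One more piece of context (again, not part of the diff): because the new 64x256 instruction shapes leave tile_descriptions_large empty, the fp8 path above now guards the large-tile cooperative kernels behind a non-empty check, mirroring the `len(tile_descriptions_medium) > 0` guard added in the 16b path. A hedged sketch of that guard pattern with a hypothetical helper name:

```python
# Illustrative sketch (not the generator's code): emit an optional kernel group
# only when its tile list is non-empty, as the guarded calls in the diff do.
def emit_if_nonempty(create_fn, manifest, layouts, tiles, data_type, schedules):
    """Skip kernel creation entirely for instruction shapes with no tiles in this group."""
    if tiles:  # same effect as len(tiles) > 0
        create_fn(manifest, layouts, tiles, data_type, schedules)
```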
