Skip to content

Commit

Permalink
Improve FP8 grouped GEMM perf via tileshape and cooperative (pytorch#…
Browse files Browse the repository at this point in the history
…3653)

Summary:
X-link: facebookresearch/FBGEMM#729

Pull Request resolved: pytorch#3653

Tuning the tile shape and leveraging the cooperative kernel schedule brings **an additional speedup of up to 1.4x** compared to the existing FP8 grouped GEMM kernel configs for non-memory-bound shapes

Reviewed By: jianyuh, jwfromm

Differential Revision: D68609019

fbshipit-source-id: e5c5680d30b60a97d0bfe50906600f133e0f2391
  • Loading branch information
jiawenliu64 authored and facebook-github-bot committed Feb 2, 2025
1 parent 4957ca1 commit 98d54f7
Showing 1 changed file with 19 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,17 @@ __global__ void set_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
Expand All @@ -210,15 +210,15 @@ __global__ void set_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(M, N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
{M, K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
{M, N, 1});
}
}
Expand Down Expand Up @@ -263,17 +263,17 @@ __global__ void set_dynamic_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
Expand All @@ -289,15 +289,15 @@ __global__ void set_dynamic_kernel_args_kernel(
zero_start_index_M[group_index], N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
{zero_start_index_M[group_index], K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
{zero_start_index_M[group_index], N, 1});
}
}
Expand Down Expand Up @@ -530,11 +530,8 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> dispatch_fp8_grouped_kernel(
if (kernel == KernelMode::Small) {
return f8f8bf16_rowwise_grouped_impl<64, 128, 128, 2, 1, 1, true>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 2, 1, 1, true>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
} else {
return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 1, 2, 1, true>(
return f8f8bf16_rowwise_grouped_impl<128, 256, 64, 2, 1, 1, false>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
}
}
Expand Down

0 comments on commit 98d54f7

Please sign in to comment.