Fix handling of dynamic FP8 grouped gemm on Nvidia (#3616)
Summary:
X-link: facebookresearch/FBGEMM#695

Pull Request resolved: #3616

This diff is the Nvidia mirror of D68686266, which changes dynamic grouped gemm to return a tensor of shape [total_M, N] when zero_start_index_M isn't provided. We also add appropriate tests to make sure the behavior doesn't break going forward.

Reviewed By: jasonjk-park, jianyuh, jiawenliu64

Differential Revision: D68689077

fbshipit-source-id: f60b533e6ec90b753dc15f2136c7cef6e162bf1c
jwfromm authored and facebook-github-bot committed Feb 3, 2025
1 parent 98d54f7 commit ecb19d9
Showing 2 changed files with 39 additions and 58 deletions.
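To make the new shape contract concrete before the diff: a minimal sketch of how a caller might recover per-group results from the output of f8f8bf16_rowwise_grouped_dynamic under the two modes described in the summary. The helper name split_grouped_output and its arguments are illustrative, not part of FBGEMM; the [G, M, N] versus [total_M, N] shapes and the unbind/split handling mirror the kernel change and the updated test below.

```python
import torch
from typing import List


def split_grouped_output(
    y: torch.Tensor,
    group_ms: List[int],
) -> List[torch.Tensor]:
    # Hypothetical post-processing helper, not part of FBGEMM.
    # When zero_start_index_M was passed, the kernel returns a padded
    # [G, M, N] tensor: one slice per group.
    if y.ndim == 3:
        return list(y.unbind(dim=0))
    # Without zero_start_index_M (the case this commit fixes), the kernel
    # now returns a stacked [total_M, N] tensor: split it by each group's
    # true row count, as the updated test does.
    return list(torch.split(y, group_ms, dim=0))
```

Called on the kernel output with the per-group row counts (the ms values in the test), this mirrors the "Massage output into proper format" step added to test_fp8_grouped_gemm.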
@@ -462,8 +462,6 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> f8f8bf16_rowwise_grouped_impl(
reinterpret_cast<GroupedGemmArgs::ElementOutput**>(output_ptr),
stride_output_ptr}};

int M = XQ[0].size(0);
int N = WQ[0].size(0);
arguments.epilogue.thread = {
{reinterpret_cast<const GroupedGemmArgs::ElementComputeEpilogue**>(
x_scale_ptr)}, // x_scale
@@ -599,7 +597,13 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel(
XQ, WQ, x_scale, w_scale, Y, zero_start_index_M));
// View as proper shape.
output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
// When zero_start_index_M is provided, we can view as [G, M, N]
if (zero_start_index_M.has_value()) {
output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
// Otherwise we view as {total_M, N}.
} else {
output = output.view({-1, WQ[0].size(0)});
}
return output;
}

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py (32 additions, 55 deletions)
@@ -726,7 +726,8 @@ def fp8_loopover_bmm(
torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)

@unittest.skipIf(
not torch.version.cuda, "Skip on AMD: GMM ops are not yet suported."
not torch.version.cuda and torch.version.hip < "6.2",
"Skip on AMD with < RoCM 6.2",
)
@settings(deadline=None)
@given(
@@ -805,63 +806,39 @@ def test_fp8_grouped_gemm(
w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous())

# FP8 grouped gemm kernel
fp8_args = (
[
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M if use_padding_zeros else None,
]
if use_dynamic
else [xq_group, wq_group, x_scale_group, w_scale_group]
)
fp8_op = (
torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
if use_dynamic
else torch.ops.fbgemm.f8f8bf16_rowwise_grouped
)
if use_cudagraph:
if use_padding_zeros:
# warmup
torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
g.replay()
y_fp8_group = y_fp8_group.unbind(dim=0)
else:
# warmup
torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
g.replay()
# warmup
fp8_op(*fp8_args)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_fp8_group = fp8_op(*fp8_args)
g.replay()
else:
if use_padding_zeros:
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
zero_start_index_M,
)
y_fp8_group = y_fp8_group.unbind(dim=0)
y_fp8_group = fp8_op(*fp8_args)

# Massage output into proper format.
if not isinstance(y_fp8_group, (tuple, list)):
if y_fp8_group.ndim == 2:
y_fp8_group = torch.split(y_fp8_group, tuple(ms.tolist()), dim=0)
else:
y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq_group,
wq_group,
x_scale_group,
w_scale_group,
)
y_fp8_group = torch.unbind(y_fp8_group)

# BF16 grouped gemm kernel
bf16_args = (