Add mx_fp8_bf16 kernel #1637

Open · wants to merge 1 commit into base: drisspg/stack/30
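This PR adds a CUTLASS-based mxfp8 x mxfp8 -> bf16 matmul kernel for SM100 (Blackwell), exposed as torchao.ops.mx_fp8_bf16. A minimal usage sketch, mirroring the new test in test/prototype/mx_formats/test_mx_mm.py below (the e4m3 element type, the 32-element scale block size, and the to_blocked scale layout all come from that test; the tensor names and sizes are illustrative):

import torch
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 256, 256, 256
a = torch.rand(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.rand(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MX fp8 (e4m3) with 32-element scale blocks along K
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# The kernel expects row-major A, column-major B, and block-swizzled e8m0 scales
a_scale = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

out = mx_fp8_bf16(a_mx._data, b_mx._data.transpose(-1, -2), a_scale, b_scale)  # (M, N) bf16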
9 changes: 5 additions & 4 deletions setup.py
@@ -215,10 +215,7 @@ def get_extensions():
extra_link_args = []
extra_compile_args = {
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
],
"nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
}

if not IS_WINDOWS:
@@ -257,12 +254,16 @@ def get_extensions():
use_cutlass = True
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
+ cutlass_tools_include_dir = os.path.join(
+ cutlass_dir, "tools", "util", "include"
+ )
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
if use_cutlass:
extra_compile_args["nvcc"].extend(
[
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
"-I" + cutlass_tools_include_dir,
"-I" + cutlass_extensions_include_dir,
]
)
102 changes: 102 additions & 0 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -0,0 +1,102 @@
import pytest
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_100,
)

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def run_matrix_test(M: int, K: int, N: int) -> float:
"""
Run matrix multiplication test with given dimensions.

Args:
M, K, N: Matrix dimensions

Returns:
float: SQNR (Signal-to-Quantization-Noise Ratio) value
"""
dtype = torch.bfloat16
device = torch.device("cuda")

# Initialize matrices
a = torch.rand((M, K), dtype=dtype, device=device)
b = torch.rand((N, K), dtype=dtype, device=device)

# Convert to MX format
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

a_fp8 = a_mx._data
b_fp8 = b_mx._data
assert b_fp8.is_contiguous()
b_fp8 = b_fp8.transpose(-1, -2)

# Get scales
a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)

a_scale_block = to_blocked(a_scale_e8)
b_scale_block = to_blocked(b_scale_e8)

# Get reference output
out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-1, -2
)

# Run implementation
out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)

# Calculate metrics
sqnr = compute_error(out_hp, out_e8_fp8)

return sqnr.item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize(
"size",
[
# Small matrices
(128, 128, 128),
(256, 256, 256),
(384, 384, 384),
# Medium matrices
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
# Large matrices
(896, 896, 896),
(1024, 1024, 1024),
# Very large matrices
(8192, 8192, 8192),
# Non-square matrices
(128, 256, 384),
(256, 384, 512),
(384, 512, 640),
# Non-aligned matrices
(129, 256, 384),
(256, 384, 536),
(133, 512, 528),
],
ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
def test_matrix_multiplication(size):
"""
Test matrix multiplication with various dimensions.
Verifies that the SQNR meets minimum quality threshold.
"""
M, K, N = size
sqnr = run_matrix_test(M, K, N)
assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
251 changes: 251 additions & 0 deletions torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
@@ -0,0 +1,251 @@
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"


#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

template<typename Element>
constexpr int GetAlignment() {
if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
return 32;
return 16;
}

template <typename ElementA,
typename ElementB,
typename ElementD,
typename MmaTileShape,
typename ClusterShape,
typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
at::Tensor& b_scale, at::Tensor& out) {
int M = a.size(0);
int K = a.size(1);
int N = b.size(1);

// A matrix configuration
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = GetAlignment<ElementB>(); // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag


using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
PerSmTileShape_MNK, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Stride and scale-factor layout types used by the kernel
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

// Initialize strides using packed stride configuration
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

// Initialize scale factor layouts using block scaled configuration
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

using DtypeA = typename ElementA::DataType;
using DtypeB = typename ElementB::DataType;
using DtypeScaleA = typename ElementA::ScaleFactorType;
using DtypeScaleB = typename ElementB::ScaleFactorType;
using DtypeOut = ElementD;

Gemm gemm;

auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{ // Mainloop arguments
A_ptr, stride_A,
B_ptr, stride_B,
SFA_ptr, layout_SFA,
SFB_ptr, layout_SFB
},
{ // Epilogue arguments
{1.0, 0.0},
nullptr, StrideC{}, // No bias for now
out_ptr, stride_D
}
};

// arguments.scheduler.max_swizzle_size = 8;

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");
// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace = a.new_empty(
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));


// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");

status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run", cutlass::cutlassGetStatusString(status));

C10_CUDA_KERNEL_LAUNCH_CHECK();

}
} // namespace (anonymous)
#endif

void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale){
TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

// Check matrix dimensions
TORCH_CHECK(a.dim() == 2, "a must be a matrix");
TORCH_CHECK(b.dim() == 2, "b must be a matrix");

// Get dimensions
auto M = a.size(0);
auto K = a.size(1);
auto N = b.size(1);

TORCH_CHECK(b.size(0) == K,
"Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);

// Needed for TMA store
TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8, but got ", N);

// Check 16-byte alignment for input tensors
TORCH_CHECK(
reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
"Input tensor 'a' must be 16-byte aligned");
TORCH_CHECK(
reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
"Input tensor 'b' must be 16-byte aligned");

auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(K, 32);
// For a_scale, we expect 128 * ceil(M/128) * ceil(K/32) elements (blocked layout)
auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale.numel() to be ", expected_a_scale_size, " but got ", a_scale.numel());

// For b_scale, we expect 128 * ceil(N/128) * ceil(K/32) elements (blocked layout)
auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale.numel() to be ", expected_b_scale_size, " but got ", b_scale.numel());
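// Illustrative example (not from the original PR): with M = 129, K = 256, N = 384
// (one of the non-aligned test shapes), num_k_blocks = ceil(256 / 32) = 8, so
//   expected_a_scale_size = 128 * ceil(129 / 128) * 8 = 128 * 2 * 8 = 2048
//   expected_b_scale_size = 128 * ceil(384 / 128) * 8 = 128 * 3 * 8 = 3072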

// Check tensor strides for optimal memory layout
TORCH_CHECK(
a.stride(1) == 1,
"Input tensor 'a' must be contiguous in the K dimension (row-major)");
TORCH_CHECK(
b.stride(0) == 1,
"Input tensor 'b' must be contiguous in the K dimension (column-major)");
}


at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
validate(a, b, a_scale, b_scale);

auto out =
at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementD = cutlass::bfloat16_t;

using MmaTileShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_2,_1,_1>;
using PerSmTileShape_MNK = Shape<_128,_128,_128>;

run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao