Commit

Add mx_fp8_bf16 kernel

stack-info: PR: #1637, branch: drisspg/stack/31

drisspg committed Jan 29, 2025
1 parent cef8f5f commit ae51147
Showing 3 changed files with 239 additions and 1 deletion.
6 changes: 5 additions & 1 deletion setup.py
@@ -218,6 +218,7 @@ def get_extensions():
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
"-std=c++20"
],
}

@@ -243,13 +244,16 @@ def get_extensions():
    use_cutlass = False
    if use_cuda and not IS_WINDOWS:
        use_cutlass = True

    if use_cutlass:
        cutlass_dir = os.path.join(third_party_path, "cutlass")
        cutlass_include_dir = os.path.join(cutlass_dir, "include")
        cutlass_tools_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
        extra_compile_args["nvcc"].extend(
            [
                "-DTORCHAO_USE_CUTLASS",
                "-I" + cutlass_include_dir,
                "-I" + cutlass_tools_include_dir,
            ]
        )
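For orientation, a minimal sketch of where extra_compile_args typically lands later in get_extensions(). The extension name and source list below are illustrative assumptions, not part of this diff:

    from torch.utils.cpp_extension import CUDAExtension

    # Illustrative only: the real name and sources are defined elsewhere in setup.py.
    ext = CUDAExtension(
        name="torchao._C",
        sources=["torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu"],
        extra_compile_args=extra_compile_args,  # carries the nvcc flags above
    )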

205 changes: 205 additions & 0 deletions torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
@@ -0,0 +1,205 @@
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

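// Gate the MX kernels: they need CUTLASS and CUDA 12.8+ (SM100 block scaling);
// otherwise everything below is compiled out.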
#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"


#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

template <typename Element>
constexpr int GetAlignment() {
  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
    return 32;
  return 16;
}
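
// Illustrative sanity check of the alignment rule (not part of the kernel
// logic): nv_float4 packs sub-byte values and needs the wider 32-element
// alignment; the mx_float8 types used below fall through to 16.
static_assert(GetAlignment<cutlass::nv_float4_t<cutlass::float_e2m1_t>>() == 32);
static_assert(GetAlignment<cutlass::mx_float8_t<cutlass::float_e4m3_t>>() == 16);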

template <typename ElementA,
          typename ElementB,
          typename ElementD,
          typename MmaTileShape,
          typename ClusterShape,
          typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out) {
  int M = a.size(0);
  int K = a.size(1);
  int N = b.size(1);

  // A matrix configuration
  using LayoutATag = cutlass::layout::RowMajor;    // Layout type for A matrix operand
  constexpr int AlignmentA = GetAlignment<ElementA>();  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
  constexpr int AlignmentB = 128;                  // Memory access granularity/alignment of B matrix in units of elements

  // C/D matrix configuration
  using ElementC = cutlass::bfloat16_t;            // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;    // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;    // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;                // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm100;            // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag


  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      PerSmTileShape_MNK, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutCTag, AlignmentC,
      ElementD, LayoutDTag, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutATag, AlignmentA,
      ElementB, LayoutBTag, AlignmentB,
      ElementAccumulator,
      MmaTileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto or using targeted scheduling policy
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Stride and scale-factor layout types deduced from the kernel
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  // Initialize strides using packed stride configuration
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

  // Initialize scale factor layouts using block scaled configuration
  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

  using DtypeA = typename ElementA::DataType;
  using DtypeB = typename ElementB::DataType;
  using DtypeScaleA = typename ElementA::ScaleFactorType;
  using DtypeScaleB = typename ElementB::ScaleFactorType;
  using DtypeOut = ElementD;

  Gemm gemm;

  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      {  // Mainloop arguments
          A_ptr, stride_A,
          B_ptr, stride_B,
          SFA_ptr, layout_SFA,
          SFB_ptr, layout_SFB
      },
      {  // Epilogue arguments
          {1.0, 0.0},
          nullptr, StrideC{},  // No bias for now
          out_ptr, stride_D
      }
  };
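  // With alpha = 1.0 and beta = 0.0 above, the epilogue computes
  // D = 1.0 * (A @ B) + 0.0 * C, i.e. a plain block-scaled matmul into bf16.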

  // arguments.scheduler.max_swizzle_size = 8;

  // Check whether the problem size is supported
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Cutlass cannot implement: ",
              cutlass::cutlassGetStatusString(status));

  // Allocate workspace memory
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = a.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));
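  // Allocating the workspace through ATen (rather than raw cudaMalloc) keeps
  // it on a's device and inside PyTorch's caching allocator.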

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Cutlass cannot initialize: ",
              cutlass::cutlassGetStatusString(status));

  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Cutlass cannot run: ",
              cutlass::cutlassGetStatusString(status));

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}  // namespace
#endif

at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                       at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
  TORCH_CHECK(a.is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be a CUDA tensor");
  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be a CUDA tensor");

  auto out =
      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementD = cutlass::bfloat16_t;

  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;

  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape,
           PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
  return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao
29 changes: 29 additions & 0 deletions torchao/ops.py
@@ -22,6 +22,9 @@
lib.define(
    "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
lib.define(
    "mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor"
)


def register_custom_op(name):
@@ -615,3 +618,29 @@ def _(
        dtype=input_scale.dtype,
        device=input.device,
    )


def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor) -> Tensor:
    """Performs a matmul between two fp8 tensors with MX scales in E8M0 and returns a bf16 tensor.

    Note: the MX scales are E8M0 values stored in uint8 tensors (for now).
    The layout of the scales is very particular; see:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        A: fp8 tensor
        B: fp8 tensor
        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout

    Returns:
        MxN bf16 Tensor
    """
    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)


@register_custom_op("torchao::mx_fp8_bf16")
def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
    """Meta impl for mx_fp8_bf16: shape/dtype propagation only, no kernel launch."""
    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
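
A hedged usage sketch (shapes only; producing correctly swizzled E8M0 scales is out of scope for this diff, so the constant scales below are a simplification):

    import torch
    from torchao.ops import mx_fp8_bf16

    # Requires a Blackwell-class GPU (SM100) and a CUDA 12.8+ build, per the
    # gating in mx_fp8_bf16.cu above.
    M, K, N = 128, 256, 128
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # B is laid out column-major (K x N), matching the kernel's expectation.
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
    # E8M0 scales live in uint8; 127 encodes 2^0 = 1.0. Real scales must use
    # the swizzled blocked layout described in the cuBLAS link above.
    A_scale = torch.full((M, K // 32), 127, device="cuda", dtype=torch.uint8)
    B_scale = torch.full((N, K // 32), 127, device="cuda", dtype=torch.uint8)
    out = mx_fp8_bf16(A, B, A_scale, B_scale)  # (M, N), torch.bfloat16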
