From ae511474c99086d3c0e11a0691ee4b68cd612ba6 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 28 Jan 2025 21:03:44 -0800
Subject: [PATCH] Add mx_fp8_bf16 kernel

stack-info: PR: https://github.com/pytorch/ao/pull/1637, branch: drisspg/stack/31
---
 setup.py                                    |   6 +-
 torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu | 205 ++++++++++++++++++++
 torchao/ops.py                              |  29 +++
 3 files changed, 239 insertions(+), 1 deletion(-)
 create mode 100644 torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu

diff --git a/setup.py b/setup.py
index 8628dc7ef4..a9f52e0b15 100644
--- a/setup.py
+++ b/setup.py
@@ -218,6 +218,7 @@ def get_extensions():
             "nvcc": [
                 "-O3" if not debug_mode else "-O0",
                 "-t=0",
+                "-std=c++20"
             ],
         }
 
@@ -243,13 +244,16 @@ def get_extensions():
     use_cutlass = False
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
+
+    if use_cutlass:
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
-    if use_cutlass:
+        cutlass_tools_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
             ]
         )
 
diff --git a/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu b/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
new file mode 100644
index 0000000000..c3232e4dc9
--- /dev/null
+++ b/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
@@ -0,0 +1,205 @@
+#include <torch/library.h>
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+#define BUILD_MX_KERNELS_CUTLASS
+#endif
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+
+#include "cute/tensor.hpp"
+#include "cutlass/detail/sm100_blockscaled_layout.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/util/packed_stride.hpp"
+
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+namespace {
+
+using namespace cute;
+
+// Block-scaled MMA operands need 32-element alignment for fp4, 16 otherwise.
+template <typename Element>
+constexpr int GetAlignment() {
+  if constexpr (std::is_same_v<Element,
+                               cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
+    return 32;
+  return 16;
+}
+
+template <typename ElementA,
+          typename ElementB,
+          typename ElementD,
+          typename MmaTileShape,
+          typename ClusterShape,
+          typename PerSmTileShape_MNK>
+void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
+              at::Tensor& b_scale, at::Tensor& out) {
+  int M = a.size(0);
+  int K = a.size(1);
+  int N = b.size(1);
+
+  // A matrix configuration
+  using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
+  constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements
+
+  // B matrix configuration
+  using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
+  constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements
+
+  // C/D matrix configuration
+  using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
+  using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
+  using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
+  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+  // Kernel functional config
+  using ElementAccumulator = float; // Element type for internal accumulation
+  using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      PerSmTileShape_MNK, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutCTag, AlignmentC,
+      ElementD, LayoutDTag, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
+      >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutATag, AlignmentA,
+      ElementB, LayoutBTag, AlignmentB,
+      ElementAccumulator,
+      MmaTileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<
+          static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
+      >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  // Stride and scale-factor layout types derived from the kernel
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
+  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+
+  // Initialize strides using packed stride configuration
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));
+
+  // Initialize scale factor layouts using block scaled configuration
+  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
+  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
+
+  using DtypeA = typename ElementA::DataType;
+  using DtypeB = typename ElementB::DataType;
+  using DtypeScaleA = typename ElementA::ScaleFactorType;
+  using DtypeScaleB = typename ElementB::ScaleFactorType;
+  using DtypeOut = ElementD;
+
+  Gemm gemm;
+
+  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
+  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
+  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
+  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
+  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {M, N, K, 1},
+      { // Mainloop arguments
+        A_ptr, stride_A,
+        B_ptr, stride_B,
+        SFA_ptr, layout_SFA,
+        SFB_ptr, layout_SFB
+      },
+      { // Epilogue arguments
+        {1.0, 0.0},
+        nullptr, StrideC{}, // No bias for now
+        out_ptr, stride_D
+      }
+  };
+
+  // arguments.scheduler.max_swizzle_size = 8;
+
+  // Check whether the problem size is supported
+  cutlass::Status status = gemm.can_implement(arguments);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");
+
+  // Allocate workspace memory
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = a.new_empty(
+      {static_cast<int64_t>(workspace_size)},
+      at::TensorOptions().dtype(at::kByte));
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm.initialize(arguments, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");
+
+  status = gemm.run(at::cuda::getCurrentCUDAStream());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run: ",
+              cutlass::cutlassGetStatusString(status));
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+} // namespace
+#endif
+
+at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto out =
+      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<_4, _4, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape,
+           PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
+}
+
+} // namespace torchao
diff --git a/torchao/ops.py b/torchao/ops.py
index f4b55c4951..6a7f47e143 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -22,6 +22,9 @@
 lib.define(
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
+lib.define(
+    "mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor"
+)
 
 
 def register_custom_op(name):
@@ -615,3 +618,29 @@ def _(
         dtype=input_scale.dtype,
         device=input.device,
     )
+
+
+def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp8 tensors with MX scales in E8M0, returning a bf16 tensor.
+
+    Note: The MX scales are E8M0 values stored in uint8 tensors (for now).
+        The layout of the scales is very particular; see:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp8 tensor
+        B: fp8 tensor
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        MxN bf16 Tensor
+
+    """
+    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp8_bf16")
+def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp8_bf16"""
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
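
Usage note (illustrative sketch, not part of the patch): the snippet below shows how
the new op is meant to be called. It assumes an SM100-class GPU and CUDA 12.8+, and
it fabricates identity E8M0 scales (uint8 value 127 encodes 2^0) rather than computing
real ones; production callers must produce scales in the swizzled blocked layout linked
from the docstring (torchao's MX prototype utilities, e.g. under
torchao.prototype.mx_formats, are the intended source of such scales).

    import torch
    from torchao.ops import mx_fp8_bf16

    # Dimensions chosen as multiples of the tile/scale-block sizes, so the
    # naive scale shapes below need no extra padding.
    M, K, N = 256, 512, 256

    # A is (M, K) row-major; B is (K, N) column-major (a transposed row-major
    # tensor), matching the kernel's LayoutATag/LayoutBTag. Both are fp8 e4m3.
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

    # One E8M0 scale per 32-element group along K, stored as uint8; 127 is a
    # no-op scale. Real scales must additionally be swizzled into the blocked
    # layout described in the cuBLAS docs linked above.
    A_scale = torch.full((M, K // 32), 127, dtype=torch.uint8, device="cuda")
    B_scale = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda")

    out = mx_fp8_bf16(A, B, A_scale, B_scale)
    assert out.shape == (M, N) and out.dtype == torch.bfloat16

With identity scales this reduces to a plain fp8 matmul accumulated in fp32 and
written out as bf16, which makes it a convenient smoke test for the kernel.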