diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c67bdd333..8744e0a6c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # NVIDIA CUTLASS Changelog + +## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23) +* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. +* [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md). +* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](media/docs/cutlass_3x_backwards_compatibility.md). +* Updates to [Functionality](media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. +* Updates to [Compatibility](/README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](/README.md#Target-Architecture). +* New warp-specialized GEMM [kernel schedules](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. +* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations. +* [CUTLASS library integration](/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler. +* Support for [Hopper GEMMs](examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features. +* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](examples/48_hopper_warp_specialized_gemm), [49](examples/49_hopper_gemm_schedules_with_collective_builder), and [50](examples/50_hopper_gemm_with_epilogue_swizzle). + ## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19) * [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. * [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel. diff --git a/CITATION.cff b/CITATION.cff index 7ae2b4b1ce..ea97f1f68e 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,33 +5,61 @@ message: >- following metadata. type: software authors: - - given-names: Andrew - email: akerr@nvidia.com - family-names: Kerr + - given-names: Vijay + family-names: Thakkar + email: vithakkar@nvidia.com + affiliation: NVIDIA + - given-names: Pradeep + family-names: Ramani + email: prramani@nvidia.com + affiliation: NVIDIA + - given-names: Cris + family-names: Cecka + email: ccecka@nvidia.com + affiliation: NVIDIA + - given-names: Aniket + family-names: Shivam + email: ashivam@nvidia.com + affiliation: NVIDIA + - given-names: Honghao + family-names: Lu + email: honghaol@nvidia.com + affiliation: NVIDIA + - given-names: Ethan + family-names: Yan + email: etyan@nvidia.com + affiliation: NVIDIA + - given-names: Jack + family-names: Kosaian + email: jkosaian@nvidia.com + affiliation: NVIDIA + - given-names: Mark + family-names: Hoemmen + email: mhoemmen@nvidia.com affiliation: NVIDIA - given-names: Haicheng family-names: Wu - affiliation: NVIDIA email: haichengw@nvidia.com - - given-names: Manish - family-names: Gupta - affiliation: Google - email: manigupta@google.com - - given-names: Dustyn - family-names: Blasig - email: dblasig@nvidia.com affiliation: NVIDIA - - given-names: Pradeep - family-names: Ramini - email: prramani@nvidia.com + - given-names: Andrew + family-names: Kerr + email: akerr@nvidia.com + affiliation: NVIDIA + - given-names: Matt + family-names: Nicely + email: mnicely@nvidia.com affiliation: NVIDIA - given-names: Duane family-names: Merrill email: dumerrill@nvidia.com affiliation: NVIDIA - - given-names: Aniket - family-names: Shivam - email: ashivam@nvidia.com + - given-names: Dustyn + family-names: Blasig + email: dblasig@nvidia.com + affiliation: NVIDIA + - given-names: Fengqi + family-names: Qiao + email: fqiao@nvidia.com affiliation: NVIDIA - given-names: Piotr family-names: Majcher @@ -49,10 +77,12 @@ authors: family-names: Wang email: jinw@nvidia.com affiliation: NVIDIA - - given-names: Matt - family-names: Nicely - email: mnicely@nvidia.com - affiliation: NVIDIA + - given-names: Manish + family-names: Gupta + affiliation: Google + email: manigupta@google.com + + repository-code: 'https://github.com/NVIDIA/cutlass' abstract: >- CUTLASS is a collection of CUDA C++ template @@ -71,12 +101,12 @@ abstract: >- flexibility simplifies their use as building blocks within custom kernels and applications. keywords: - - 'cutlass, tensor cores, cuda' + - 'cutlass, tensor cores, cuda, cute, nvidia, gpu, linear algebra, matrix computations' license: BSD-3-Clause -license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt -version: '2.11.0' -date-released: '2022-11-19' +license-url: https://github.com/NVIDIA/cutlass/blob/v3.0.0/LICENSE.txt +version: '3.0.0' +date-released: '2023-01-23' identifiers: - type: url - value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0" - description: The GitHub release URL of tag 2.11.0 + value: "https://github.com/NVIDIA/cutlass/tree/v3.0.0" + description: The GitHub release URL of tag 3.0.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b8d7c8225..e879f780c3 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR) +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) if(cutlass_LOADED) # If CUTLASS has been previously fetched and loaded, don't do it again. @@ -39,35 +39,40 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 2.11.0 LANGUAGES CXX) +project(CUTLASS VERSION 3.0.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) -if (CUDA_VERSION VERSION_LESS 10.2) - message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.") -elseif (CUDA_VERSION VERSION_LESS 11.0) - message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.") +if (CUDA_VERSION VERSION_LESS 11.3) + message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.") +elseif (CUDA_VERSION VERSION_LESS 11.4) + message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.") +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5) + message(FATAL_ERROR "GCC version must be at least 7.5!") +endif() + +if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) + message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") endif() find_package(Doxygen QUIET) # -# CUTLASS 2.x requires C++11 +# CUTLASS 3.x requires C++17 # -if (NOT IMPLICIT_CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 11) - set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CXX_EXTENSIONS OFF) -endif() +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) if(CUTLASS_NATIVE_CUDA) - set(CMAKE_CUDA_STANDARD 11) + set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) else() - if (NOT IMPLICIT_CMAKE_CXX_STANDARD) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11) - endif() + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17) endif() - + if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) endif() @@ -107,29 +112,14 @@ if (CUTLASS_ENABLE_TESTS) endif() set(CUTLASS_NVCC_ARCHS_SUPPORTED "") -if (NOT CUDA_VERSION VERSION_LESS 7.5) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53) -endif() -if (NOT CUDA_VERSION VERSION_LESS 8.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61) -endif() -if (NOT CUDA_VERSION VERSION_LESS 9.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70) +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87) endif() -if (NOT CUDA_VERSION VERSION_LESS 9.2) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72) +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90) endif() -if (NOT CUDA_VERSION VERSION_LESS 10.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75) -endif() -if (NOT CUDA_VERSION VERSION_LESS 11.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80) -endif() -if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86) -endif() -if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90) +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") @@ -271,6 +261,7 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() + if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") @@ -288,6 +279,15 @@ if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) endif() endif() +if (CUTLASS_ENABLE_OPENMP_TESTS) + find_package(OpenMP) + if(OpenMP_CXX_FOUND) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS}) + else() + message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.") + endif() +endif() + list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-Wconversion>) list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-fno-strict-aliasing>) @@ -313,10 +313,6 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) endif() - if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) - message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") - endif() - # There are numerous Clang versions that can work with each CUDA toolkit and the # the checks are not very useful so we are turning them off and using testing to # ensure the various combinations work properly. @@ -341,6 +337,7 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags) link_libraries(nvidia::cudart) + link_libraries(nvidia::cuda_driver) endif() # Support for 128-bit integers if using NVIDIA C++ compiler @@ -530,6 +527,8 @@ target_include_directories( $ $ $ + $ + $ ) install( diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 21357b5f52..5a159d8c57 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -7,63 +7,77 @@ This is the official list of CUTLASS developers and contributors. ## DEVELOPERS -Andrew Kerr -Haicheng Wu -Manish Gupta -Dustyn Blasig -Pradeep Ramani -Cris Cecka -Vijay Thakkar -Aniket Shivam -Honghao Lu -Ethan Yan -Zhaodong Chen -Jack Kosaian -Yujia Zhai -Naila Farooqui -Piotr Majcher -Paul Springer -Jin Wang -Chinmay Talegaonkar -Shang Zhang -Scott Yokim -Markus Hohnerbach -Aditya Atluri -David Tanner -Manikandan Ananth +Vijay Thakkar
+Pradeep Ramani
+Cris Cecka
+Aniket Shivam
+Jack Kosaian
+Mark Hoemmen
+Honghao Lu
+Ethan Yan
+Haicheng Wu
+Andrew Kerr
+Dustyn Blasig
+Fengqi Qiao
+Duane Merrill
+Yujia Zhai
+Shang Zhang
+Piotr Majcher
+Paul Springer
+Markus Hohnerbach
+Jin Wang
+Aditya Atluri
+ +## CuTe +Cris Cecka
+Vijay Thakkar
## CUTLASS Product Manager -Matthew Nicely - +Matthew Nicely
+ +## Former CUTLASS Developers +Manish Gupta
+Naila Farooqui
+David Tanner
+Manikandan Ananth
+Zhaodong Chen
+Chinmay Talegaonkar
+ ## CONTRIBUTORS -Timothy Costa -Julien Demouth -Brian Fahs -Michael Goldfarb -Mostafa Hagog -Fei Hu -Alan Kaatz -Tina Li -Timmy Liu -Duane Merrill -Kevin Siu -Markus Tavenrath -John Tran -Vicki Wang -Junkai Wu -Fung Xie -Albert Xu -Jack Yang -Xiuxia Zhang -Nick Zhao +Timothy Costa
+Julien Demouth
+Brian Fahs
+Michael Garland
+Michael Goldfarb
+Mostafa Hagog
+Fei Hu
+Alan Kaatz
+Tina Li
+Timmy Liu
+Wei Liu
+Duane Merrill
+Kevin Siu
+Markus Tavenrath
+John Tran
+Vicki Wang
+Junkai Wu
+Fung Xie
+Albert Xu
+Yang Xu
+Jack Yang
+Scott Yokim
+Xiuxia Zhang
+Nick Zhao
## ACKNOWLEDGEMENTS -Girish Bharambe -Luke Durant -Olivier Giroux -Stephen Jones -Rishkul Kulkarni -Bryce Lelbach -Joel McCormack -Kyrylo Perelygin +Girish Bharambe
+Luke Durant
+Carter Edwards
+Olivier Giroux
+Stephen Jones
+Rishkul Kulkarni
+Bryce Lelbach
+Joel McCormack
+Kyrylo Perelygin
+Sean Treichler
diff --git a/README.md b/README.md index b58465132f..a89b8f49b4 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.11 +# CUTLASS 3.0 -_CUTLASS 2.11 - November 2022_ +_CUTLASS 3.0 - January 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing -high-performance matrix-multiplication (GEMM) and related computations at all levels +high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for hierarchical decomposition and data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes these "moving parts" into reusable, modular software components abstracted by C++ template -classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized -and tuned via custom tiling sizes, data types, and other algorithmic policy. The -resulting flexibility simplifies their use as building blocks within custom kernels -and applications. +classes. Primitives for different levels of a conceptual parallelization hierarchy +can be specialized and tuned via custom tiling sizes, data types, +and other algorithmic policy. The resulting flexibility simplifies their use +as building blocks within custom kernels and applications. To support a wide variety of applications, CUTLASS provides extensive support for mixed-precision computations, providing specialized data-movement and @@ -21,60 +21,75 @@ point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32), single-precision floating point (FP32), [FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), double-precision floating -point (FP64) types, integer data types (4b and 8b), and binary data types (1b). -CUTLASS demonstrates warp-synchronous matrix multiply operations -targeting the programmable, high-throughput _Tensor Cores_ implemented by -NVIDIA's Volta, Turing, and Ampere architectures. - -CUTLASS implements high-performance Convolution via the implicit GEMM algorithm. -Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of -CUTLASS's modular GEMM pipeline. -This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below. +point (FP64) types, integer data types (4b and 8b), and binary data types (1b). +CUTLASS demonstrates warp-synchronous matrix multiply operations +targeting the programmable, high-throughput _Tensor Cores_ implemented by +NVIDIA's Volta, Turing, Ampere, and Hopper architectures. See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. -# What's New in CUTLASS 2.11 - -CUTLASS 2.11 is an update to CUTLASS adding: -- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. -- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths. -- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel. -- Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8. -- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions. -- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm). -- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv). -- [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop). -- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. -- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115). -- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers). -- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release: - - Maxwell and Pascal GPU architectures - - Ubuntu 16.04 - - CUDA 10.2 - - C++ 11 -- **Future requirement announcement:** CUTLASS plans to add the following requirements in the next major release: - - Minimum C++ standard - C++17 +CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data. +CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. + +The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. + +CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design +and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md). + +In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + +# What's New in CUTLASS 3.0 + +CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture. + +- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md) +- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md). +- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe. +- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe. +- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs. +- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA. +- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality. +- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu). +- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format +for kernel procedural names. +- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release. + +## New architecture, compiler, and CUDA Toolkit requirements + +Minimum requirements: + +- Architecture: Volta +- Compiler: Must support at least C++17 +- CUDA Toolkit version: 11.4 + +CUTLASS 3.0 *removes support* for the following: + +- Maxwell and Pascal GPU architectures +- Ubuntu 16.04 +- CUDA 10.2 +- C++ language versions less than 17. **See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.** # Performance -

+

CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, -they exhibit performance comparable to cuBLAS for scalar GEMM +they exhibit peak performance comparable to cuBLAS for scalar GEMM computations. The above figure shows CUTLASS performance relative to cuBLAS -for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/), -an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/), -an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/), -and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/) -compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's +for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture), +an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture), +an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture), +and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture). +CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads). +Tensor Core operations are implemented using CUDA's [mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). -

+

When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad) kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) @@ -83,39 +98,48 @@ as shown in the above figure. Tensor Core operations are still implemented usin # Compatibility -CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit). - -It is also compatible with CUDA 11.x. +CUTLASS requires a C++17 host compiler and +performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8. ## Operating Systems We have tested the following environments. |**Operating System** | **Compiler** | |-----------------|----------| -| Windows 10 | Microsoft Visual Studio 2015| -| | Microsoft Visual Studio 2017| -| | Microsoft Visual Studio 2019| -| Ubuntu 18.04 | GCC 7.5.0 | +| Ubuntu 18.04 | GCC 7.5.0 | | Ubuntu 20.04 | GCC 10.3.0 | | Ubuntu 22.04 | GCC 11.2.0 | -Additionally, CUTLASS may be built with clang. -See [these instructions](media/docs/quickstart.md#clang) for more details. +Note: We plan to add Windows (MSVC) & Clang compiler support soon. ## Hardware -CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on -any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. - -|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**| -|---|---|---|---| -|NVIDIA Tesla V100|7.0|9.2|10.1| -|NVIDIA TitanV|7.0|9.2|10.1| -|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2| -|NVIDIA Tesla T4|7.5|10.0|10.2| -|NVIDIA A100|8.0|11.0|11.0| -|NVIDIA A10 |8.6|11.1|11.1| -|NVIDIA GeForce 3090|8.6|11.1|11.1| -|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0| +CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs. + +|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**| +|---|---|---| +|NVIDIA V100 Tensor Core GPU |7.0|11.4| +|NVIDIA TitanV |7.0|11.4| +|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4| +|NVIDIA T4 |7.5|11.4| +|NVIDIA A100 Tensor Core GPU |8.0|11.4| +|NVIDIA A10 |8.6|11.4| +|NVIDIA GeForce RTX 3090 |8.6|11.4| +|NVIDIA GeForce RTX 4090 |8.9|11.8| +|NVIDIA L40 |8.9|11.8| +|NVIDIA H100 Tensor Core GPU |9.0|11.8| + +## Target Architecture + +In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). + +The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error. + +``` +cmake .. -DCUTLASS_NVCC_ARCHS="90a" +``` + +Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures. # Documentation @@ -125,7 +149,9 @@ CUTLASS is described in the following documents and the accompanying - [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS - [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS - [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA -- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts +- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts - [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS - [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project - [Terminology](media/docs/terminology.md) - describes terms used in the code @@ -161,7 +187,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc ``` Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels -for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify +for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0. +To reduce compile time you can specify the architectures to build CUTLASS for by changing the CMake configuration setting `CUTLASS_NVCC_ARCHS`. @@ -224,6 +251,23 @@ include/ # client applications should target this directory transform/ # code specialized for layout, type, and domain transformations * # core vocabulary types, containers, and basic numeric operations + + cute/ # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy + + algorithm/ # Definitions of core operations such as copy, gemm, and operations on cute::tuples + + arch/ # Bare bones PTX wrapper structs for copy and math instructions + + atom/ # Meta-information either link to or built from arch/ operators + + mma_atom.hpp # cute::Mma_Atom and cute::TiledMma + + copy_atom.hpp # cute::Copy_Atom and cute::TiledCopy + + *sm*.hpp # Arch specific meta-information for copy and math operations + + * # Core library types such as Shape, Stride, Layout, Tensor, and associated operations + ``` ### CUTLASS SDK Examples @@ -269,7 +313,7 @@ By default, only one tile size is instantiated for each data type, math instruct To instantiate all, set the following environment variable when running CMake from an empty `build/` directory. Beware, this results in *thousands* of kernels and long build times. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all ... $ make cutlass_profiler -j16 ``` diff --git a/cuBLAS.cmake b/cuBLAS.cmake index 6936f0a9d6..db1e36fc1c 100644 --- a/cuBLAS.cmake +++ b/cuBLAS.cmake @@ -40,7 +40,7 @@ elseif(NOT TARGET cublas) find_path( _CUBLAS_INCLUDE_DIR - NAMES cublas.h + NAMES cublas_v2.h HINTS ${CUBLAS_INCLUDE_PATH} ENV CUBLAS_INCLUDE_PATH diff --git a/examples/10_planar_complex/CMakeLists.txt b/examples/10_planar_complex/CMakeLists.txt index c24c05030f..11ca9724ec 100644 --- a/examples/10_planar_complex/CMakeLists.txt +++ b/examples/10_planar_complex/CMakeLists.txt @@ -45,5 +45,6 @@ target_link_libraries( PRIVATE cutlass_lib cutlass_tools_util_includes + cuda ) diff --git a/examples/11_planar_complex_array/CMakeLists.txt b/examples/11_planar_complex_array/CMakeLists.txt index 7434656eed..64125b5256 100644 --- a/examples/11_planar_complex_array/CMakeLists.txt +++ b/examples/11_planar_complex_array/CMakeLists.txt @@ -45,5 +45,6 @@ target_link_libraries( PRIVATE cutlass_lib cutlass_tools_util_includes + cuda ) diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h index c2a20d751a..dde3c073a8 100644 --- a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -35,7 +35,7 @@ GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + lightweight full reduction kernel (ApplyFinalReduction) + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) - + */ #pragma once @@ -77,7 +77,7 @@ template < typename ElementLayernormCompute_, typename ElementOutput, typename ThreadblockShape_, - bool IsShiftedVariance_ = false + bool IsShiftedVariance_ = false > class ApplyFinalReduction { public: @@ -91,7 +91,7 @@ class ApplyFinalReduction { using Layout = cutlass::layout::RowMajor; using TensorVariance = TensorRef; - using TensorMean = TensorRef; + using TensorMean = TensorRef; static bool const kIsShiftedVariance = IsShiftedVariance_; @@ -463,7 +463,7 @@ class EpilogueVisitorLayerNorm { for (int rid = 0; rid < kRowIterations; ++rid) { int row_step_offset = rid * kDeltaRow; int row_offset = thread_offset_row_base + step_offset + row_step_offset; - bool is_load = (row_offset < extent_.row()); + bool is_load = (row_offset < extent_.row()); shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); } @@ -504,9 +504,9 @@ class EpilogueVisitorLayerNorm { using Minus = cutlass::minus; using Exp = cutlass::fast_exp_op; - Minus minus; - Mul mul; - Exp exponential; + [[maybe_unused]] Minus minus; + [[maybe_unused]] Mul mul; + [[maybe_unused]] Exp exponential; LayernormFragment result; @@ -605,7 +605,7 @@ class EpilogueVisitorLayerNorm { CUTLASS_DEVICE ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { using ConvertShiftK = cutlass::NumericConverter; - ConvertShiftK convert_shift_k; + ConvertShiftK convert_shift_k; ElementOutput shift_k_val; // Computes the address to load shift_k element @@ -614,7 +614,7 @@ class EpilogueVisitorLayerNorm { arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); // Converts data type to return ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); - + return converted_shift_k_val; } @@ -689,7 +689,7 @@ class GemmLayernorm { // // Type definitions // - + static bool const kInternalTranspose = cutlass::platform::is_same::value; static bool const kIsShiftedVariance = IsShiftedVariance_; @@ -704,14 +704,14 @@ class GemmLayernorm { using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; - // These are mandatory layouts and data types + // These are mandatory layouts and data types // that are inheritated from pre-defined params - + using LayoutSumSqr = LayoutInputScaleBias; using LayoutSum = LayoutInputScaleBias; using ElementMean = ElementInputScaleBias; - using ElementVariance = ElementInputScaleBias; + using ElementVariance = ElementInputScaleBias; /////////////////////////////////////////////////////////////////////////////////////////////// @@ -720,7 +720,7 @@ class GemmLayernorm { using LayoutInputA1 = LayoutOutput_; using LayoutInputB1 = LayoutOutput_; using LayoutOutputC0 = LayoutOutput_; - using LayoutOutputC1 = LayoutOutput_; + using LayoutOutputC1 = LayoutOutput_; using ElementInputA0 = ElementInputA0_; using ElementInputB0 = ElementInputB0_; @@ -747,7 +747,7 @@ class GemmLayernorm { static int const kStages1 = Stages1; using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - + /////////////////////////////////////////////////////////////////////////////////////////////// using MapArguments = cutlass::gemm::kernel::detail::MapArguments< diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index 9bae934ded..720159965a 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -180,9 +180,9 @@ struct FMHAGrouped { /// Default ctor CUTLASS_HOST_DEVICE - Arguments(): + Arguments(): problem_count(0), - threadblock_count(0), + threadblock_count(0), ptr_Q(nullptr), ptr_K(nullptr), ptr_P(nullptr), @@ -201,7 +201,7 @@ struct FMHAGrouped { /// Ctor CUTLASS_HOST_DEVICE - Arguments( + Arguments( GemmCoord *problem_sizes0, GemmCoord *problem_sizes1, int problem_count, @@ -219,7 +219,7 @@ struct FMHAGrouped { typename LayoutO::Stride::LongIndex *ldo, bool causal, GemmCoord *host_problem_sizes=nullptr - ): + ): problem_sizes0(problem_sizes0), problem_sizes1(problem_sizes1), problem_count(problem_count), @@ -311,7 +311,7 @@ struct FMHAGrouped { ldv(args.ldv), ldo(args.ldo), causal(args.causal) - { + { } @@ -464,7 +464,7 @@ struct FMHAGrouped { void operator()(Params const ¶ms, SharedStorage &shared_storage) { auto& m_prime = shared_storage.m_prime; auto& s_prime = shared_storage.s_prime; - auto& si = shared_storage.after_mm0.si; + [[maybe_unused]] auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.mi; ProblemVisitor problem_visitor( diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 6321e7dde8..6cb292c0ec 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -481,7 +481,7 @@ struct AttentionKernel { SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); auto& m_prime = shared_storage.m_prime; auto& s_prime = shared_storage.s_prime; - auto& si = shared_storage.after_mm0.si; + [[maybe_unused]] auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.mi; static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); diff --git a/examples/41_fused_multi_head_attention/mma_from_smem.h b/examples/41_fused_multi_head_attention/mma_from_smem.h index 271a9f3a2c..21ac4d104c 100644 --- a/examples/41_fused_multi_head_attention/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/mma_from_smem.h @@ -384,7 +384,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< // but not supported as it worsens perf: older gpus < sm80 don't // support async tranfers and have to waste registers CUTLASS_DEVICE - bool set_prologue_done(bool value) {} + void set_prologue_done(bool value) {} CUTLASS_DEVICE static void prologue( typename Base::SharedStorage& shared_storage, @@ -695,7 +695,7 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< } CUTLASS_DEVICE - bool set_prologue_done(bool value) { + void set_prologue_done(bool value) { prologue_done_ = value; } diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu index 48d28bc22e..12739a0577 100644 --- a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu @@ -34,7 +34,7 @@ "classic data-parallel" and "Split-K" decompositions. For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition - for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) + for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) Requires NVIDIA Ampere or newer device (SM80+). diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu new file mode 100644 index 0000000000..599d1d5083 --- /dev/null +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0 + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and + convert them implicitly to TF32. + + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + + Examples: + + $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TilesShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(5120), n(4096), k(4096), + alpha(1.f), beta(0.f), + iterations(1000) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "48_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); + stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); + stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); + stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k}, + block_A.get(), + stride_A, + block_B.get(), + stride_B, + {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt b/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt new file mode 100644 index 0000000000..b00c7244d2 --- /dev/null +++ b/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 48_hopper_warp_specialized_gemm + 48_hopper_warp_specialized_gemm.cu + ) diff --git a/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu b/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu new file mode 100644 index 0000000000..1d92bef93c --- /dev/null +++ b/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu @@ -0,0 +1,522 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example leveraging collective operation builders. + + This example showcases the use of CUTLASS's CollectiveBuilder to easily construct performant kernels + targetting the NVIDIA Hopper architecture. + + Background and motivation + ------------------------- + CUTLASS kernels are highly parameterizable via template parameters. To ease the selection of template + parameters, CUTLASS 2 leveraged DefaultGemmConfigurations. Given a small set of parameters, such as + the data types of operands and the compute capability of the GPU, DefaultGemmConfigurations defined sensible + defaults for the many other parameters to the kernel (e.g., warp shape, stage count). + + However, DefaultGemmConfigurations leave multiple opportunities for improvement, which are addressed + in CUTLASS 3: + (1) DefaultGemmConfigurations do not allow one to use a more-performant set of parameters without + specifying every parameter. For example, the DefaultGemmConfigurations for GEMMs targetting + Ampere specify that three pipeline stages should be used regardless of the sizes of operands. + If one wished to increase this value, one would also need to specify all other template parameters. + This leaves a gap between a high-level ease-of-use interface and a lower-level detailed interface. + (2) A new DefaultGemmConfiguration was required for each combination of operand types, GPU architecture, + and operation type (e.g., Tensor Core or SIMT). This led to increased code size to cover each unique + configuration and a lack of extensibility from one DefaultGemmConfiguration to another. + + Alongside these opportunities for improvement, the Hopper architecture offers new features that increase + the number of valid configurations of a kernel. In addition to the many template parameters already available + in CUTLASS 2 kernels, CUTLASS 3 kernels targetting Hopper also have various scheduling modes to select from that control: + (1) how data is to be loaded (e.g., using the Hopper TMA feature or Ampere cp.async) + (2) how work is to be divided among warps in a thread block (e.g., whether to use "warp specialization") + (3) whether persistent thread blocks should be used + This increased configuration space further motivates rethinking DefaultGemmConfigurations. + + Introduction to the CollectiveBuilder + ------------------------------------- + CUTLASS 3 introduces the CollectiveBuilder to further ease the process of selecting template parameters + for kernels targetting Hopper. Similar to the DefaultGemmConfigurations used in CUTLASS 2, the CollectiveBuilder + takes in a small set of template parameters (e.g., the data types of operands A and B). It then automatically + determines the data loading strategy to use depending on whether the Hopper TMA feature can be used with the provided + parameters. If one does not indicate a particular scheduling policy or stage count to use (by using `Auto` template + parameters), the CollectiveBuilder will also automatically select these. + + Unlike DefaultGemmConfigurations a parital specialization of the CollectiveBuilder is not needed for many + configurations of operand types. Instead the CollectiveBuilder "builds" a configuration based on generic + properties of the specified operands, layouts, and other parameters. For example, when the stage count + is set to `Auto`, the CollectiveBuilder may automatically calculate the maximum number of stages that + will fit in shared memory given the types of operands and the thread block shape, rather than simply using + a single default value. + + Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide + every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this + manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be + a convenience interface. + + Note also that, while the selections made by CollectiveBuilder attempt to maximize performance, this is not + a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject + to change in future CUTLASS releases -- do not rely on `Auto` if you require a specific scheduling policy and/or + stage count to be used. + + Details of this example + ----------------------- + This example walks through the use of the CollectiveBuilder with various schedules and stage counts specified. + This example also illustrates how CUTLASS 3 GEMMs targetting Hopper automatically support batched GEMMs by simply + extending the problem size with an additional tensor rank. + + Example usage: + $ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \ + --m=2048 --n=2048 --k=2048 --l=2 +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l; + float alpha, beta; + + Options(): + help(false), + error(false), + m(2048), n(2048), k(2048), l(1), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "49_hopper_gemm_schedules_with_collective_builder\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " performant kernels targetting NVIDIA's Hopper architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective +// operation builders by specializing the GEMM only on the kernel schedule it will use and the +// number of pipeline stages. +// +// For either option, one can use a special `Auto` type that tells the CollectiveBuilder +// to select an appropriate value on its own. The CollectiveBuilder will attempt to select +// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore, +// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases +// -- do not rely on `Auto` if you require a specific scheduling policy. +template < + // Type of kernel schedule to generate + class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto, + // Number of pipeline stages to use + class StageCountType = cutlass::gemm::collective::StageCountAuto +> +struct ExampleRunner { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + static constexpr int kAlignmentA = 8; + static constexpr int kAlignmentB = 8; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, kAlignmentA, + cutlass::half_t, LayoutB, kAlignmentB, + float, + Shape<_128,_128,_64>, Shape<_2,_1,_1>, + StageCountType, + KernelScheduleType + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); + using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, float alpha, float beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + typename Gemm::EpilogueOutputOp::ElementCompute(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + typename Gemm::EpilogueOutputOp::ElementCompute(beta), + ref_C, + ref_D, + typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + block_A.get(), + stride_A, + block_B.get(), + stride_B, + {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + } + + return passed; + } + +}; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, bool passed) { + std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // This first example constructs a GEMM using the default schedule and stage count provided by + // the CollectiveBuilder. The scheduling policy that is expected to be most performant will be + // selected and the maximum number of stages that can fit in shared memory will be selected. + // + // This example is equivalent to declaring + // ExampleRunner + // Each of the `Auto` types indicate that the CollectiveBuilder should determine the scheduling policy and + // stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change + // -- do not rely on `Auto` if you require a specific scheduling policy. + ExampleRunner<> auto_schedule_auto_stage_runner; + passed = auto_schedule_auto_stage_runner.run(options, hw_info); + print_result("Automatically-selected schedule and stage count", passed); + + // One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto + // with the number of stages to use (5 in this case). + ExampleRunner auto_schedule_5_stage_runner; + passed = auto_schedule_5_stage_runner.run(options, hw_info); + print_result("Automatically-selected schedule with 5 stages", passed); + + // One can also override the scheduling policy to use. In this case, use the KernelTma scheduling + // policy, which specifies that the Hopper TMA feature should be used. + ExampleRunner tma_schedule_auto_stage_runner; + passed = tma_schedule_auto_stage_runner.run(options, hw_info); + print_result("TMA schedule with automatically-selected stage count", passed); + + // Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized + // scheduling policy. + // + // Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy + // used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and + // ExampleRunner. However, this default is subject to + // change in future releases -- do not rely on `Auto` if you require a specific scheduling policy. + ExampleRunner ws_schedule_auto_stage_runner; + passed = ws_schedule_auto_stage_runner.run(options, hw_info); + print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed); + + // Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized + // scheduling policy, leveraging persistent thread blocks. + ExampleRunner ws_persistent_schedule_auto_stage_runner; + passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info); + print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed); + +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt b/examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt new file mode 100644 index 0000000000..30c6e5ead0 --- /dev/null +++ b/examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 49_hopper_gemm_schedules_with_collective_builder + 49_hopper_gemm_schedules_with_collective_builder.cu + ) diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu new file mode 100644 index 0000000000..7323cc39de --- /dev/null +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -0,0 +1,529 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example to create a GEMM kernel with custom Collectives + + The following example shows how to assemble a custom GEMM kernel that spells out the Collectives + directly instead of using a builder and, in the process, instance a more efficient Epilogue + (from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue. + + The GemmUniversal API takes 3 main template arguments: + (1) the problem shape / extents + (2) the collective mainloop type + (3) the collective epilogue type + + While the collecive mainloop can be stamped out using a CollectiveBuilder interface, it is + possible to build a custom collective mainloop directly as well. Furthermore, since epilogues + do not yet have a builder interface, this example shows how to instantiate a more-efficient + epilogue alongside the collective mainloop. + + Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set + of trade-offs. So it is recommended that users look at the options available under + cutlass/epilogue/collective and evaluate for their particular scenario. + + Please refer to examples 48, 49 to learn more about kernel schedules and other CuTe examples + present in `test/unit/cute` to famialiarize with the basics of CuTe. + + Examples: + + $ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l; + int alpha, beta; + + Options(): + help(false), + error(false), + m(2048), n(2048), k(2048), l(1), + alpha(1), beta(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1); + cmd.get_cmd_line_argument("beta", beta, 0); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "50_hopper_gemm_with_vectorized_epilogue\n\n" + << "Hopper GEMM Example with Epilogue Swizzle.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +// Wrapper to run and verify a GEMM. +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementCompute(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementCompute(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + block_A.get(), + stride_A, + block_B.get(), + stride_B, + {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + } + + return passed; + } + +}; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // Problem configuration + using ElementA = int8_t; + using ElementB = int8_t; + using ElementAcc = int32_t; + using ElementOutput = int8_t; + + // Note : Only TN WGMMA Gemm is supported currently in 3.0 + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Tiling configuration selection + using TileShape = Shape<_128,_64,_128>; + + // Choosing a thread block cluster larger than 1 allows us to Multicast data across thread blocks + using ClusterShape = Shape<_1,_2,_1>; + + // + // Assembling the CollectiveMainloop type + // + + // Pipeline Depth to be used i.e number of A, B buffers in shared memory + constexpr int PipelineStages = 8; + + // Let's choose a Warp-Specialized Mainloop implemention which uses TMA + // Note : This requires / assumes the tensors to be 16B aligned + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; + + // TN => K Major for both A & B + static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K; + static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K; + + // We use the SS op selector as both A, B operands are read directly from SMEM (for TN WGMMA) + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>())); + + // A loads can be optimized with multicast if cluster-n > 1 + using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1, + cute::SM90_TMA_LOAD, + cute::SM90_TMA_LOAD_MULTICAST>::type; + + // B loads can be optimized with multicast if cluster-m > 1 + using GmemTiledCopyB = std::conditional< cute::size(shape<0>(ClusterShape{})) == 1, + cute::SM90_TMA_LOAD, + cute::SM90_TMA_LOAD_MULTICAST>::type; + + using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{})) + >()); + + using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{})) + >()); + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, // Does not need a SmemCopyAtom, since A is read directly from SMEM + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, // Does not need a SmemCopyAtom, since B is read directly from SMEM + cute::identity + >; + + // + // Assembling the Collective Epilogue Type + // + + // Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension + using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>, + Stride,_32>>; + + // 128 threads loading 16 elements each (to get vectorized global stores) + using TileShapeS2R = Shape<_128,_16>; + + // Layout to ensure bank-conflict free loads & stores + using SmemLayout = ComposedLayout< + Swizzle<3,4,3>, + smem_ptr_flag_bits::value>, + PreSwizzleLayout>; + + // Tiled copy from Smem to Registers + // Note : CuTe will vectorize this copy if the tiling + swizzling above were right + using TiledCopyS2R = TiledCopy< + Copy_Atom, + Layout< Shape<_128,_16>, + Stride<_16,_1>>, + TileShapeS2R>; + + using Epilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + SmemLayout, + Copy_Atom, + TiledCopyS2R, + Copy_Atom>; + + // + // Assembling the GemmKernel + // + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + Epilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + passed = runner.run(options, hw_info); + + std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? "Passed" : "Failed") << std::endl; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt b/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt new file mode 100644 index 0000000000..b213d3936f --- /dev/null +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 50_hopper_gemm_with_epilogue_swizzle + 50_hopper_gemm_with_epilogue_swizzle.cu + ) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index fac98b8ebc..a063bd81dd 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -54,12 +54,14 @@ function(cutlass_example_add_executable NAME) CUTLASS cutlass_tools_util_includes $<$:nvidia::cublas> + cuda ) target_include_directories( ${NAME} PRIVATE ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR} + ${CUTLASS_EXAMPLES_UTILS_DIR} ) install( @@ -118,6 +120,7 @@ foreach(EXAMPLE 36_gather_scatter_fusion 37_gemm_layernorm_gemm_fusion 38_syr2k_grouped + cute 39_gemm_permute 41_fused_multi_head_attention 42_ampere_tensorop_group_conv @@ -125,6 +128,9 @@ foreach(EXAMPLE 45_dual_gemm 46_depthwise_simt_conv2dfprop 47_ampere_gemm_universal_streamk + 48_hopper_warp_specialized_gemm + 49_hopper_gemm_schedules_with_collective_builder + 50_hopper_gemm_with_epilogue_swizzle ) add_subdirectory(${EXAMPLE}) diff --git a/examples/cute/CMakeLists.txt b/examples/cute/CMakeLists.txt new file mode 100644 index 0000000000..c210d634af --- /dev/null +++ b/examples/cute/CMakeLists.txt @@ -0,0 +1,30 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +add_subdirectory(tutorial) diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..97867ded44 --- /dev/null +++ b/examples/cute/tutorial/CMakeLists.txt @@ -0,0 +1,34 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + sgemm_nt_1 + sgemm_nt_1.cu +) + diff --git a/examples/cute/tutorial/sgemm_nt_1.cu b/examples/cute/tutorial/sgemm_nt_1.cu new file mode 100644 index 0000000000..fc4839a5bf --- /dev/null +++ b/examples/cute/tutorial/sgemm_nt_1.cu @@ -0,0 +1,426 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 +# include "cutlass/util/cublas_wrappers.hpp" +#endif +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(CThreadLayout{}))::value) +void +gemm_device(MShape M, NShape N, KShape K, + TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, + TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, + TC * C, CStride dC, CBlockLayout , CThreadLayout tC, + Alpha alpha, Beta beta) +{ + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); + CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); + + //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M + //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N + CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K) + auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K) + + // Represent the full tensors + auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) + auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) + auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) + + // Get the appropriate blocks for this thread block -- + // potential for thread block locality + auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K) + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + + auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // + // Partition the copying of A and B tiles across the threads + // + + // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB + // Default is a raked partition, but can be changed with Step parameter + + auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) + auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) + + auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) + auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) + + // + // Define C accumulators and A/B partitioning + // + + // TUTORIAL: Example of partitioning via projections of tC + + // Partition sA (M,K) by the rows of tC + auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) + // Partition sB (N,K) by the cols of tC + auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) + // Partition gC (M,N) by the tile of tC + auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) + + // Allocate the accumulators -- same size as the projected data + auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print("mA\n"); + print(mA.shape()); print("\n"); print(mA.stride()); + print("\n\ngA\n"); + print(gA.shape()); print("\n"); print(gA.stride()); + print("\n\ntAgA\n"); + print(tAgA.shape()); print("\n"); print(tAgA.stride()); + print("\n\nsA\n"); + print(sA.shape()); print("\n"); print(sA.stride()); + print("\n\ntAsA\n"); + print(tAsA.shape()); print("\n"); print(tAsA.stride()); + print("\n\n"); + } +#endif + +#if 0 + if(thread0()) { + print("mB\n"); + print(mB.shape()); print("\n"); print(mB.stride()); + print("\n\ngB\n"); + print(gB.shape()); print("\n"); print(gB.stride()); + print("\n\ntBgB\n"); + print(tBgB.shape()); print("\n"); print(tBgB.stride()); + print("\n\nsB\n"); + print(sB.shape()); print("\n"); print(sB.stride()); + print("\n\ntBsB\n"); + print(tBsB.shape()); print("\n"); print(tBsB.stride()); + print("\n\n"); + } +#endif + +#if 0 + if(thread0()) { + print("mC\n"); + print(mC.shape()); print("\n"); print(mC.stride()); + print("\n\ngC\n"); + print(gC.shape()); print("\n"); print(gC.stride()); + print("\n\ntCsA\n"); + print(tCsA.shape()); print("\n"); print(tCsA.stride()); + print("\n\ntCsB\n"); + print(tCsB.shape()); print("\n"); print(tCsB.stride()); + print("\n\ntCgC\n"); + print(tCgC.shape()); print("\n"); print(tCgC.stride()); + print("\n\ntCrC\n"); + print(tCrC.shape()); print("\n"); print(tCrC.stride()); + print("\n\n"); + } +#endif + +#if 1 + + // TUTORIAL: Example of a very simple compute loop + // Data is read from global to shared memory via the tA|tB partitioning + // gemm(.) operates on the shared memory directly via the tC partitioning + + auto k_max = size<2>(tAgA); + + for (int k = 0; k < k_max; ++k) + { + // Copy gmem to smem + copy(tAgA(_,_,k), tAsA); + copy(tBgB(_,_,k), tBsB); + + // In case copy uses cp.async, make sure that the cp.async + // instructions are ordered with respect to other cp.async + // instructions (fence), then wait on all the outstanding copy + // operations (wait<0>()). __syncthreads() alone does not do + // this. + // + // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. + // This is equivalent to cp.async.commit_group followed by + // cp.async_wait_group 0. This should make the first + // cp_async_fence() (which also issues cp.async.commit_group) + // redundant. The tutorial works as-is, so we'll leave the + // redundant fence in for now and study its removal later. + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + // Compute gemm on smem + gemm(tCsA, tCsB, tCrC); + + __syncthreads(); + } + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); +} + + +template +void +gemm(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); + auto dB = make_stride(Int<1>{}, ldB); + auto dC = make_stride(Int<1>{}, ldC); + + // Define block sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + + // Define the block layouts (static) + auto sA = make_layout(make_shape(bM,bK)); + auto sB = make_layout(make_shape(bN,bK)); + auto sC = make_layout(make_shape(bM,bN)); + + // Define the thread layouts (static) + auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); + + dim3 dimBlock(size(tC)); + dim3 dimGrid(ceil_div(size(M), size(bM)), + ceil_div(size(N), size(bN))); + gemm_device + <<< dimGrid, dimBlock, 0, stream >>> + (M, N, K, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); +} + +#include +#include +#include + +void test_gemm(int m, int n, int k) +{ + cute::device_init(0); + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + + using TA = float; + using TB = float; + using TC = float; + using TI = float; + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + TI alpha = 1.0; + TI beta = 0.0; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + +#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 + // + // cuBLas + // + + cublasHandle_t handle; + cublasCreate(&handle); + + // Run once + d_C = h_C; + blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, + m, n, k, + &alpha, + d_A.data().get(), m, + d_B.data().get(), n, + &beta, + d_C.data().get(), m); + CUTE_CHECK_LAST(); + + thrust::host_vector cublas_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, + m, n, k, + &alpha, + d_A.data().get(), m, + d_B.data().get(), n, + &beta, + d_C.data().get(), m); + } + double cublas_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUBLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000); + +#else + + std::cout << "Verification by comparison with cuBLAS is disabled, " + "either because the CMake option CUTLASS_ENABLE_CUBLAS " + "was explicitly set to OFF, or because CMake could not find cuBLAS. " + "If you would like to enable verification with cuBLAS, " + "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, " + "rerun CMake, and recompile this example.\n"; + +#endif // CUTLASS_ENABLE_CUBLAS + + // + // CuTe + // + + // Run once (and check) + d_C = h_C; + gemm(m, n, k, + alpha, + d_A.data().get(), m, + d_B.data().get(), n, + beta, + d_C.data().get(), m); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(m, n, k, + alpha, + d_A.data().get(), m, + d_B.data().get(), n, + beta, + d_C.data().get(), m); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + +#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 + printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100); + + auto host_matrix_to_const_column_major_cute_tensor = + [](const auto& X, int num_rows, int num_cols, int LDX) { + const auto shape = cute::Shape{num_rows, num_cols}; + const auto strides = cute::Stride{1, LDX}; + return cute::make_tensor(X.data(), cute::make_layout(shape, strides)); + }; + + const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m); + // B^T is k x n, so B is n x k. + const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n); + const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m); + const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m); + print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view); + +#endif // CUTLASS_ENABLE_CUBLAS +} + + +int main(int argc, char** argv) +{ + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + test_gemm(m, n, k); + + return 0; +} diff --git a/include/cute/algorithm/axpby.hpp b/include/cute/algorithm/axpby.hpp new file mode 100644 index 0000000000..a613417d39 --- /dev/null +++ b/include/cute/algorithm/axpby.hpp @@ -0,0 +1,79 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor && y) +{ + return axpby(alpha, x, beta, y); +} + +// +// AXPBY +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor & y) +{ + auto isBetaZero = (beta == Int<0>{}); + + CUTE_UNROLL + for (int i = 0; i < size(x); ++i) { + y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); + } +} + +} // end namespace cute diff --git a/include/cute/algorithm/clear.hpp b/include/cute/algorithm/clear.hpp new file mode 100644 index 0000000000..ce7b51095d --- /dev/null +++ b/include/cute/algorithm/clear.hpp @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +clear(Tensor&& tensor) +{ + return clear(tensor); +} + +// +// Set elements to zero +// +template +CUTE_HOST_DEVICE +void +clear(Tensor& tensor) +{ + using T = typename Tensor::value_type; + + fill(tensor, T{}); +} + +} // end namespace cute diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp new file mode 100644 index 0000000000..04ceb051a4 --- /dev/null +++ b/include/cute/algorithm/copy.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_atom, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_vec(Tensor const& src, + Tensor && dst) +{ + return copy_vec(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_atom, src, dst); +} + +// +// copy_if -- Predicated Copy +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + auto copy_op = select_elementwise_copy(src, dst); + + CUTE_UNROLL + for (int i = 0; i < size(src); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } + } +} + +// +// copy_if -- Predicated CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PredTensor const& pred, // (Rest...) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + auto src_v = group_modes<1,R>(src); + auto dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(src_v); ++i) { + if (pred(i)) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } +} + +// +// copy_vec -- attempt vectorized copy with VecType +// + +template +CUTE_HOST_DEVICE +void +copy_vec(Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) + { + /* @pre is_aligned(src.data()) && + * is_aligned(dst.data()) + */ + auto src_v = recast(src); + auto dst_v = recast(dst); + +#if 0 + if (thread0()) { + print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType))); + print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n"); + print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n"); + } +#endif + + return copy_if(TrivialPredTensor{}, src_v, dst_v); + } else { +#if 0 + if (thread0()) { + print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType))); + print(" "); print(layout(src)); print("\n"); + print(" "); print(layout(dst)); print("\n"); + } +#endif + + return copy_if(TrivialPredTensor{}, src, dst); + } +} + +// +// copy -- auto-vectorizing copy +// + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst) +{ + constexpr int N = decltype(max_common_vector(src, dst))::value; + +#if 0 + if (thread0()) { + print("copy -- found a max_common_vector of %d\n", N); + print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n"); + print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n"); + } +#endif + + if constexpr (N <= 1) { + return copy_if(TrivialPredTensor{}, src, dst); + } else { + constexpr int vec_bits = N * sizeof_bits::value; + using VecType = uint_bit_t; + return copy_vec(src, dst); + } +} + +// +// copy -- CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, + Tensor & dst) +{ + return copy_if(copy_atom, TrivialPredTensor{}, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const&, + Tensor const& src, + Tensor & dst) +{ + return copy(src, dst); +} + +} // end namespace cute diff --git a/include/cute/algorithm/fill.hpp b/include/cute/algorithm/fill.hpp new file mode 100644 index 0000000000..bc0c4ad16d --- /dev/null +++ b/include/cute/algorithm/fill.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +fill(Tensor&& tensor, T const& value) +{ + return fill(tensor, value); +} + +namespace detail +{ + +// Prefer fill(tensor.data(), value), if possible +template +CUTE_HOST_DEVICE +auto +fill(Tensor& tensor, T const& value, prefer<1>) + -> decltype(fill(tensor.data(), value)) +{ + fill(tensor.data(), value); +} + +// Default implementation +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value, prefer<0>) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = value; + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value) +{ + return detail::fill(tensor, value, prefer<1>{}); +} + +} // end namespace cute diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp new file mode 100644 index 0000000000..e66cd975d5 --- /dev/null +++ b/include/cute/algorithm/functional.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +/** C++14 extensions */ + +namespace cute { + +/**************/ +/** Identity **/ +/**************/ + +struct identity { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return std::forward(arg); + } +}; + +template +struct constant_fn { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&&...) const { + return r_; + } + R r_; +}; + +/***********/ +/** Unary **/ +/***********/ + +#define CUTE_LEFT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP std::forward(arg); \ + } \ + } +#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return std::forward(arg) OP ; \ + } \ + } +#define CUTE_NAMED_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP (std::forward(arg)); \ + } \ + } + +CUTE_LEFT_UNARY_OP(unary_plus, +); +CUTE_LEFT_UNARY_OP(negate, -); +CUTE_LEFT_UNARY_OP(bit_not, ~); +CUTE_LEFT_UNARY_OP(logical_not, !); +CUTE_LEFT_UNARY_OP(dereference, *); +CUTE_LEFT_UNARY_OP(address_of, &); +CUTE_LEFT_UNARY_OP(pre_increment, ++); +CUTE_LEFT_UNARY_OP(pre_decrement, --); + +CUTE_RIGHT_UNARY_OP(post_increment, ++); +CUTE_RIGHT_UNARY_OP(post_decrement, --); + +CUTE_NAMED_UNARY_OP(abs_fn, abs); +CUTE_NAMED_UNARY_OP(conjugate, cute::conj); + +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP +#undef CUTE_NAMED_UNARY_OP + +/************/ +/** Binary **/ +/************/ + +#define CUTE_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return std::forward(lhs) OP std::forward(rhs); \ + } \ + } +#define CUTE_NAMED_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return OP (std::forward(lhs), std::forward(rhs)); \ + } \ + } + + +CUTE_BINARY_OP(plus, +); +CUTE_BINARY_OP(minus, -); +CUTE_BINARY_OP(multiplies, *); +CUTE_BINARY_OP(divides, /); +CUTE_BINARY_OP(modulus, %); + +CUTE_BINARY_OP(plus_assign, +=); +CUTE_BINARY_OP(minus_assign, -=); +CUTE_BINARY_OP(multiplies_assign, *=); +CUTE_BINARY_OP(divides_assign, /=); +CUTE_BINARY_OP(modulus_assign, %=); + +CUTE_BINARY_OP(bit_and, &); +CUTE_BINARY_OP(bit_or, |); +CUTE_BINARY_OP(bit_xor, ^); +CUTE_BINARY_OP(left_shift, <<); +CUTE_BINARY_OP(right_shift, >>); + +CUTE_BINARY_OP(bit_and_assign, &=); +CUTE_BINARY_OP(bit_or_assign, |=); +CUTE_BINARY_OP(bit_xor_assign, ^=); +CUTE_BINARY_OP(left_shift_assign, <<=); +CUTE_BINARY_OP(right_shift_assign, >>=); + +CUTE_BINARY_OP(logical_and, &&); +CUTE_BINARY_OP(logical_or, ||); + +CUTE_BINARY_OP(equal_to, ==); +CUTE_BINARY_OP(not_equal_to, !=); +CUTE_BINARY_OP(greater, >); +CUTE_BINARY_OP(less, <); +CUTE_BINARY_OP(greater_equal, >=); +CUTE_BINARY_OP(less_equal, <=); + +CUTE_NAMED_BINARY_OP(max_fn, cute::max); +CUTE_NAMED_BINARY_OP(min_fn, cute::min); + +#undef CUTE_BINARY_OP +#undef CUTE_NAMED_BINARY_OP + +/**********/ +/** Meta **/ +/**********/ + +template +struct bound_fn { + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(T&& arg) { + return fn_(arg_, std::forward(arg)); + } + + Fn fn_; + Arg arg_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +bind(Fn const& fn, Arg const& arg) { + return bound_fn{fn, arg}; +} + +} // end namespace cute diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp new file mode 100644 index 0000000000..6e2ce612c0 --- /dev/null +++ b/include/cute/algorithm/gemm.hpp @@ -0,0 +1,718 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +/** The gemm algorithm takes four (or three) tensors and computes + * D += A * B + C + * It dispatches based on the number of modes each tensor has: + * + * 1. `(V) x (V) => (V)`. + * The element-wise product of vectors. Dispatches to FMA or MMA. + * 2. `(M) x (N) => (M,N)`. + * The outer product of vectors. Dispatches to [3] with new mode K=(1). + * 3. `(M,K) x (N,K) => (M,N)`. + * The product of matrices. Dispatches to [5] with MMA vector-mode V. + * 4. `(V,M) x (V,N) => (V,M,N)`. + * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). + * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. + * The batched product of matrices. Dispatches to [4] for each (k). + */ + +namespace cute +{ + +// +// Three arguments to four +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(mma, C, A, B, C); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(D, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(mma, C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(mma, D, A, B, C); +} + +// +// Default MMA is UniversalFMA +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + using MMA = MMA_Atom::value_type, + typename Tensor::value_type, + typename Tensor::value_type, + typename Tensor::value_type>>; + + return gemm(MMA{}, D, A, B, C); +} + +// +// Thread-Local Register-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 1 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V) Logical data + Tensor const& A, // (V) Logical data + Tensor const& B, // (V) Logical data + Tensor const& C) // (V) Logical data +{ + // No static assertions on (V), MMA checks compatibility + mma.call(D, A, B, C); +} + +// Dispatch [2]: (M) x (N) => (M,N) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M) Logical data + Tensor const& B, // (N) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + gemm(mma, + D, // (M,N) + make_tensor(A.data(), append<2>(A.layout())), // (M,1) + make_tensor(B.data(), append<2>(B.layout())), // (N,1) + C); // (M,N) +} + +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M) Logical data + Tensor const& B, // (V,N) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + // REGISTER .reuse OPTIMIZATIONS + + auto M = size<1>(A); + auto N = size<1>(B); + + // 64-bit traversal specialization -- serpentine path + if (size<0>(A) * sizeof(typename Tensor::value_type) == 8 && + size<0>(B) * sizeof(typename Tensor::value_type) == 8) + { +#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) + // Row-major iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } +#else + // Col-major iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } +#endif + } else + + // 32-bit traversal specialization -- kinked serpentine path + if (size<0>(A) * sizeof(typename Tensor::value_type) == 4 && + size<0>(B) * sizeof(typename Tensor::value_type) == 4) + { +#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) + // Row-major iteration + CUTE_UNROLL + for (int m = 0; m < M; m += 2) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 2) ? N-1-n : n; + gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); + + if (m+1 < M) { + gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); + } + } + } +#else + // Col-major iteration + CUTE_UNROLL + for (int n = 0; n < N; n += 2) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + // Kinked serpentine traversal for maximum register reuse + int ms = (n & 2) ? M-1-m : m; + gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); + + if (n+1 < N) { + gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); + } + } + } +#endif + } else { + // Fallback to serpentine loop + // Col-major iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_rmem::value && + BLayout::rank == 3 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) { + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } +} + +// +// Thread-Local Shared-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +// Dispatch [2]: (M) x (N) => (M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_smem::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_smem::value && + BLayout::rank == 3 && is_smem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto rA = MMA_Atom::make_fragment_A(A); + auto rB = MMA_Atom::make_fragment_B(B); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) + { + copy(A(_,_,k), rA(_,_,k)); + copy(B(_,_,k), rB(_,_,k)); + // Thread-level register gemm for k + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } +} + +// +// Collective Shared-Memory GEMMs +// + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, + BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +{ + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + static_assert(std::is_same_v>, TypeA>, + "ALoadTransformOp functor must accept and return value of type TA::value_type"); + static_assert(std::is_same_v>, TypeB>, + "BLoadTransformOp functor must accept and return value of type TB::value_type"); + + // Original, static size of the problem + auto M = size<0>(sC); + auto N = size<1>(sC); + auto K = size<1>(sA); + + // Block size of the compute tile + auto BLK_M = tile_size<0>(thr_mma); + auto BLK_N = tile_size<1>(thr_mma); + auto BLK_K = tile_size<2>(thr_mma); + + // Compute the "residues" + auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] + auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] + auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] + + // Shift the origin so k_residue is zeroth tile + sA.data() = &sA(0,k_residue); + sB.data() = &sB(0,k_residue); + +#if 0 + if (thread0()) { + printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); + printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N)); + printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); + } +#endif + + // + // MMA Partitioning + // + + // Round the layout extents up to BLK_X + Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); + Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); + Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); + +#if 0 + if (thread0()) { + print(rounded_sA.layout()); print("\n"); + print(rounded_sB.layout()); print("\n"); + print(rounded_sC.layout()); print("\n"); + } +#endif + + // Partition the sA and sB tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K) + Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N) + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + +#if 0 + if (thread0()) { + print(tCsA.layout()); print("\n"); + print(tCsB.layout()); print("\n"); + print(tCsC.layout()); print("\n"); + print(tCrA.layout()); print("\n"); + print(tCrB.layout()); print("\n"); + print(tCrC.layout()); print("\n"); + } +#endif + + // + // PREDICATION + // + + // Allocate the preds for only the MMA-mode of tCsA and tCsB + Tensor tCpA = make_tensor(size<0>(tCsA)); + Tensor tCpB = make_tensor(size<0>(tCsB)); + + // Create coordinate tensors on a single compute block for predication + Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) + + // Repeat partitioning with thr_mma + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) + + // Populate the m and n predicates + CUTE_UNROLL + for (int i = 0; i < size(tCpA); ++i) { + tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); + } + CUTE_UNROLL + for (int i = 0; i < size(tCpB); ++i) { + tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); + } + +#if 0 + printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", + threadIdx.x, + int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), + int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); +#endif + + // + // PREFETCH k_block = 0 (with k-predication) + // + + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I + if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m + tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; + } + } + } + + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I + if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k + CUTE_UNROLL + for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n + tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; + } + } + } + // + // MAINLOOP + // + + // Clear accumulators + clear(tCrC); + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. + if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m + tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + } + } + + CUTE_UNROLL + for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n + tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + } + } + } + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } + + // + // Epilogue + // + + Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) + + const bool isBetaZero = (beta == Beta{}); + + // Custom axpby_if for now + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsC); ++m) + { + CUTE_UNROLL + for (int n = 0; n < size<2>(tCsC); ++n) + { + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsC); ++i) + { + if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && + (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) + { + tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n); + } + } + } + } +} + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC) +{ + gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); +} + +} // end namespace cute diff --git a/include/cute/algorithm/prefer.hpp b/include/cute/algorithm/prefer.hpp new file mode 100644 index 0000000000..700edff0ba --- /dev/null +++ b/include/cute/algorithm/prefer.hpp @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +// Infinite types that inherit from each other +template +struct prefer : prefer {}; + +template <> +struct prefer<0> {}; + +// Can be used to preferencially overload implementations +// Higher N in prefer have higher priority. + +} // end namespace cute diff --git a/include/cute/algorithm/tensor_algorithms.hpp b/include/cute/algorithm/tensor_algorithms.hpp new file mode 100644 index 0000000000..258ddec680 --- /dev/null +++ b/include/cute/algorithm/tensor_algorithms.hpp @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** Common algorithms on (hierarchical) tensors */ + +#pragma once + +#include + +#include + +namespace cute +{ + +// +// for_each +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor const& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + static_cast(op)(tensor(i)); + } +} + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + static_cast(op)(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor&& tensor, UnaryOp&& op) +{ + return for_each(tensor, static_cast(op)); +} + +// +// transform +// + +// Similar to std::transform but does not return number of elements affected +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = static_cast(op)(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor, UnaryOp&& op) +{ + return transform(tensor, std::forward(op)); +} + +} // end namespace cute diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp new file mode 100644 index 0000000000..35b19f9612 --- /dev/null +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -0,0 +1,846 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +/** Common algorithms on (hierarchical) tuples */ +/** Style choice: + * Forward params [using static_cast(.)] for const/non-const/ref/non-ref args + * but don't bother forwarding functions as ref-qualified member fns are extremely rare + */ + +namespace cute +{ + +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f, seq) +{ + return f(get(static_cast(t))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +// +// Transform Apply +// (t, f, g) => g(f(t_0),f(t_1),...) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T&& t, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)), + get(static_cast(t2)))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T&& t, F&& f, G&& g) +{ + return detail::tapply(static_cast(t), f, g, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) +{ + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) +{ + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); +} + +// +// For Each +// (t, f) => f(t_0),f(t_1),...,f(t_n) +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(T&& t, F&& f) +{ + detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +for_each_leaf(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Transform +// (t, f) => (f(t_0),f(t_1),...,f(t_n)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T const& t, F&& f) +{ + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, F&& f) +{ + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// find and find_if +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f, seq<>) +{ + return cute::integral_constant::value>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f, seq) +{ + if constexpr (decltype(f(get(t)))::value) { + return cute::integral_constant{}; + } else { + return find_if(t, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::find_if(t, f, tuple_seq{}); + } else { + return cute::integral_constant{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find(T const& t, X const& x) +{ + return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false +} + +template +auto +none_of(T const& t, F&& f) +{ + return cute::integral_constant::value>{}; +} + +template +auto +all_of(T const& t, F&& f) +{ + auto not_f = [&](auto const& a) { return !f(a); }; + return cute::integral_constant::value>{}; +} + +template +auto +any_of(T const& t, F&& f) +{ + return cute::integral_constant{}; +} + +// +// Filter +// (t, f) => +// + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T const& t, F&& f) +{ + return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, F&& f) +{ + return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +// +// Fold (Reduce, Accumulate) +// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) +// + +namespace detail { + +// This impl compiles much faster than cute::apply and variadic args +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold(T&& t, V&& v, F&& f, seq<>) +{ + return static_cast(v); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold(T&& t, V&& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + return f(static_cast(v), get(static_cast(t))); + } else { + return fold(static_cast(t), + f(static_cast(v), get(static_cast(t))), + f, + seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), + static_cast(v), + f, + tuple_seq{}); + } else { + return f(static_cast(v), static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold_first(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), + get<0>(static_cast(t)), + f, + make_range<1,std::tuple_size>::value>{}); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// front, back, take, unwrap +// + +// Get the first non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +front(T&& t) +{ + if constexpr (is_tuple>::value) { + return front(get<0>(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Get the last non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +back(T&& t) +{ + if constexpr (is_tuple>::value) { + constexpr int N = tuple_size>::value; + return back(get(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Takes the elements in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(T const& t) +{ + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); +} + +// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple +template +CUTE_HOST_DEVICE constexpr +auto +unwrap(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 1) { + return unwrap(get<0>(t)); + } else { + return t; + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Flatten a hierarchical tuple to a tuple of depth one. +// + +template +CUTE_HOST_DEVICE constexpr +auto +flatten_to_tuple(T const& t) +{ + if constexpr (is_tuple::value) { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(T const& t) +{ + if constexpr (is_tuple::value) { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// insert and remove and replace +// + +namespace detail { + +// Shortcut around tuple_cat for common insert/remove/repeat cases +template +CUTE_HOST_DEVICE constexpr +auto +construct(T const& t, X const& x, seq, seq, seq) +{ + return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); +} + +} // end namespace detail + +// Insert x into the Nth position of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +insert(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Remove the Nth element of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +remove(T const& t) +{ + return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); +} + +// Replace the Nth element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Replace the first element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_front(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the last element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_back(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs of tuple_size N +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make a tuple of Xs the same profile as tuple +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat_like(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return repeat_like(a,x); }); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Group the elements [B,E) of a T into a single element +// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) +// => T<_1,_2,T<_3,_4>,_5,_6>{} +template +CUTE_HOST_DEVICE constexpr +auto +group(T const& t) +{ + return detail::construct(t, take(t), make_seq{}, seq<0>{}, make_range::value>{}); +} + +// +// Extend a T to rank N by appending/prepending an element +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); + } else { + return cute::make_tuple(a, x); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + static_assert(N > 1); + return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); + } else { + return cute::make_tuple(x, a); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Inclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f, seq) +{ + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v_next + auto t_next = replace(t, v_next); + +#if 0 + std::cout << "ISCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + if constexpr (sizeof...(Is) == 0) { + return t_next; + } else { + return iscan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f) +{ + return detail::iscan(t, v, f, tuple_seq{}); +} + +// +// Exclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + // Replace I with v + return replace(t, v); + } else { + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v + auto t_next = replace(t, v); + +#if 0 + std::cout << "ESCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + // Recurse + return escan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f) +{ + return detail::escan(t, v, f, tuple_seq{}); +} + +// +// Zip (Transpose) +// + +// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input +// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip_(T const& t, seq) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t, seq, seq) +{ + static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); + return cute::make_tuple(detail::zip_(t, seq{})...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple>::value) { + return detail::zip(t, tuple_seq{}, tuple_seq>{}); + } else { + return cute::make_tuple(t); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Convenient to pass them in separately +template +CUTE_HOST_DEVICE constexpr +auto +zip(T0 const& t0, T1 const& t1, Ts const&... ts) +{ + return zip(cute::make_tuple(t0, t1, ts...)); +} + +// +// zip2_by -- A guided zip for rank-2 tuples +// Take a tuple like ((A,a),((B,b),(C,c)),d) +// and produce a tuple ((A,(B,C)),(a,(b,c),d)) +// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide, seq, seq) +{ + // zip2_by produces the modes like ((A,a),(B,b),...) + auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); + + // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) + return cute::make_tuple(cute::make_tuple(get(split)...), + cute::make_tuple(get(split)..., get(t)...)); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide) +{ + if constexpr (is_tuple::value) { + constexpr int TR = tuple_size::value; + constexpr int GR = tuple_size::value; + static_assert(TR >= GR, "Mismatched ranks"); + return detail::zip2_by(t, guide, + make_range< 0, GR>{}, + make_range{}); + } else { + static_assert(tuple_size::value == 2, "Mismatched ranks"); + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace cute diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp new file mode 100644 index 0000000000..6fd9edd382 --- /dev/null +++ b/include/cute/arch/cluster_sm90.hpp @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +# define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +CUTE_DEVICE void cluster_arrive_relaxed() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +CUTE_DEVICE void cluster_arrive() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.aligned;\n" : : ); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +CUTE_DEVICE void cluster_wait() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.wait.aligned;\n" : : ); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +CUTE_DEVICE void cluster_sync() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + cluster_arrive(); + cluster_wait(); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +// Returns the dim3 grid size in terms of number of clusters. +CUTE_DEVICE dim3 cluster_grid_dims() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return gridDim; +#endif +} + +// Returns the dim3 cluster rank in the grid. +CUTE_DEVICE dim3 cluster_id_in_grid() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return blockIdx; +#endif +} + +// Returns the relative dim3 block rank local to the cluster. +CUTE_DEVICE dim3 block_id_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {0,0,0}; +#endif +} + +// Returns the dim3 cluster shape. +CUTE_DEVICE dim3 cluster_shape() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {1,1,1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTLASS_DEVICE uint32_t block_rank_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. +CUTE_HOST_DEVICE uint32_t elect_one_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %rx;\n" + ".reg .pred %px;\n" + " elect.sync %rx|%px, %2;\n" + "@%px mov.s32 %1, 1;\n" + " mov.s32 %0, %rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +} // end namespace cute diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp new file mode 100644 index 0000000000..aa7bb333ed --- /dev/null +++ b/include/cute/arch/copy.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Direct Copy for any type +// + +template +struct UniversalCopy +{ + using SRegisters = S[1]; + using DRegisters = D[1]; + + CUTE_HOST_DEVICE static constexpr void + copy(S const& src, + D & dst) + { + dst = src; + } +}; + +// +// Placeholder for the copy algorithm's default, auto-vectorizing behavior +// + +struct DefaultCopy +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint128_t[1]; +}; + +using AutoVectorizingCopy = DefaultCopy; + +} // end namespace cute diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp new file mode 100644 index 0000000000..fda6340d35 --- /dev/null +++ b/include/cute/arch/copy_sm75.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_LDSM_SM75_ENABLED +#endif + +namespace cute +{ + +struct SM75_U32x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +struct SM75_U32x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +struct SM75_U32x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +struct SM75_U16x2_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +struct SM75_U16x4_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +struct SM75_U16x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); +#endif + } +}; + +// +// Legacy LDSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_ldsm(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_ldsm_trans(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +} // end namespace cute diff --git a/include/cute/arch/copy_sm80.hpp b/include/cute/arch/copy_sm80.hpp new file mode 100644 index 0000000000..c6c44121bd --- /dev/null +++ b/include/cute/arch/copy_sm80.hpp @@ -0,0 +1,138 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED +#endif + +namespace cute +{ + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTE_HOST_DEVICE +void +cp_async_fence() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but N previous cp.async.commit_group operations have committed. +template +CUTE_HOST_DEVICE +void +cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } +#endif +} + +template +CUTE_HOST_DEVICE +void +cp_async_wait(Int) +{ + return cp_async_wait(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/copy_sm90.hpp b/include/cute/arch/copy_sm90.hpp new file mode 100644 index 0000000000..6ac96438c1 --- /dev/null +++ b/include/cute/arch/copy_sm90.hpp @@ -0,0 +1,225 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +namespace cute +{ + +struct SM90_U32x1_STSM_N +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t & smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x2_STSM_N +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x4_STSM_N +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x2_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x4_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x8_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +// +// Legacy STSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_stsm(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_stsm_trans(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp new file mode 100644 index 0000000000..ca8320f665 --- /dev/null +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include +#include +#include // to_Format<[u]intX> +#include // to_Format + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////////////////////////////// +/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns +/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels) +/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction) +////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Initialize barrier present in shared memory +CUTE_HOST_DEVICE +void +initialize_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int thread_count = 1) // Thread count expected to arrive/wait on this barrier +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(thread_count)); +#endif +} + +// Set the number of bytes transfered per transaction +CUTE_HOST_DEVICE +void +set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + uint32_t bytes) // Number of bytes transfered by per TMA transaction +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(bytes)); +#endif +} + +// Barrier wait +CUTE_HOST_DEVICE +void +wait_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int phase_bit) // Current phase bit the barrier waiting to flip +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(smem_int_ptr), + "r"(phase_bit)); + +#endif +} + +// Barrier arrive +CUTE_HOST_DEVICE +void +arrive_barrier(uint64_t& smem_barrier) // 64 bits user-manged barrier in smem +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .b64 state; \n" + "mbarrier.arrive.shared.b64 state, [%0];\n" + "}\n" + :: "r"(smem_int_ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TMA Descriptor and utilities +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMA { + +enum class SmemSwizzleBits : uint8_t { + DISABLE = 0, + B32 = 1, + B64 = 2, + B128 = 3, +}; + +#if (__CUDACC_VER_MAJOR__ >= 12) + +template +inline CUtensorMapDataType to_CUtensorMapDataType() { + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } +} + +inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { + switch (t) { + default: assert(false && "Unknown SmemSwizzleBits!"); + case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; + } +} + +#endif // (__CUDACC_VER_MAJOR__ >= 12) +} // end namespace TMA + +#if (__CUDACC_VER_MAJOR__ >= 12) +using TmaDescriptor = CUtensorMap; +#else +using TmaDescriptor = struct { char bytes[128]; }; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Initiates a TensorMap Prefetch +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) + asm volatile ( + "prefetch.tensormap [%0];" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp new file mode 100644 index 0000000000..d6025e4ad8 --- /dev/null +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -0,0 +1,552 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_2D_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_3D_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_4D_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_5D_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE : Initiates a TMA copy from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Wait on prior N (Count) TMA_STORE instructions to complete +template +CUTE_HOST_DEVICE static void +tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(Count) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma.hpp b/include/cute/arch/mma.hpp new file mode 100644 index 0000000000..1c1058fcb9 --- /dev/null +++ b/include/cute/arch/mma.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Direct FMA for any type +// + +template +struct UniversalFMA +{ + using DRegisters = D[1]; + using ARegisters = A[1]; + using BRegisters = B[1]; + using CRegisters = C[1]; + + CUTE_HOST_DEVICE static constexpr void + fma(D & d, + A const& a, + B const& b, + C const& c) + { + // Forward to an ADL/cute free function for these types + using cute::fma; + fma(d, a, b, c); + } +}; + +} // end namespace cute diff --git a/include/cute/arch/mma_sm61.hpp b/include/cute/arch/mma_sm61.hpp new file mode 100644 index 0000000000..32a9fbbcb5 --- /dev/null +++ b/include/cute/arch/mma_sm61.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) +# define CUTE_ARCH_MMA_SM61_ENABLED +#endif + +namespace cute +{ + +struct SM61_DP4A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +struct SM61_DP2A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/include/cute/arch/mma_sm70.hpp b/include/cute/arch/mma_sm70.hpp new file mode 100644 index 0000000000..139e60041a --- /dev/null +++ b/include/cute/arch/mma_sm70.hpp @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +# define CUTE_ARCH_MMA_SM70_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_ARCH_MMA_SM70_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM70 MMA 884 F16F16F16 +// + +struct SM70_8x8x4_F16F16F16F16_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_TT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM70 MMA 884 F16F16F32 +// + +struct SM70_8x8x4_F32F16F16F32_TN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_TT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma_sm75.hpp b/include/cute/arch/mma_sm75.hpp new file mode 100644 index 0000000000..20d2b56c0b --- /dev/null +++ b/include/cute/arch/mma_sm75.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) +# define CUTE_ARCH_MMA_SM75_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_MMA_SM75_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM75 MMA 1688 F16F16F32 +// + +struct SM75_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM75 MMA 8816 S8S8S32 +// + +struct SM75_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp new file mode 100644 index 0000000000..6050500a47 --- /dev/null +++ b/include/cute/arch/mma_sm80.hpp @@ -0,0 +1,2132 @@ + /************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_MMA_SM80_ENABLED +#endif + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM80_16x8x4_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x4 TN +struct SM80_8x8x4_F64F64F64F64_TN +{ + using DRegisters = double[2]; + using ARegisters = double[1]; + using BRegisters = double[1]; + using CRegisters = double[2]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, + double const& a0, + double const& b0, + double const& c0, double const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=d"(d0), "=d"(d1) + : "d"(a0), + "d"(b0), + "d"(c0), "d"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +// MMA 8x8x4 TN with Planar Complex multiplication +struct SM80_8x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = complex[2]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex const& a0, + complex const& b0, + complex const& c0, complex const& c1) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + a0.real(), + b0.real(), + c0.real(), c1.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.imag(), + b0.real(), + c0.imag(), c1.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + -a0.imag(), + b0.imag(), + d0.real(), d1.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.real(), + b0.imag(), + d0.imag(), d1.imag()); + } +}; + +// MMA 8x8x4 TN with Gaussian Complex multiplication: +// (a + bi)*(c + di) +// yields +// t0 += a*c +// t1 += b*d +// t2 += (a+b)*(c+d) +// then +// re = t0 - t1 +// im = t2 - t0 - t1 +struct SM80_8x8x4_GC64C64C64GC64_TN +{ + struct GaussComplex { + double t0, t1, t2; + + CUTE_HOST_DEVICE //constexpr + operator complex() const { return complex(t0 - t1, t2 - t0 - t1); } + + CUTE_HOST_DEVICE friend //constexpr + complex operator*(GaussComplex const& a, complex const& b) { return static_cast>(a) * b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator*(complex const& a, GaussComplex const& b) { return b * a; } + + CUTE_HOST_DEVICE friend //constexpr + complex operator+(GaussComplex const& a, complex const& b) { return static_cast>(a) + b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator+(complex const& a, GaussComplex const& b) { return b + a; } + }; + + using DRegisters = GaussComplex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = GaussComplex[2]; + + CUTE_HOST_DEVICE static void + fma(GaussComplex & d0, GaussComplex & d1, + complex const& a0, + complex const& b0, + GaussComplex const& c0, GaussComplex const& c1) + { + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, + a0.real(), + b0.real(), + c0.t0, c1.t0); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, + a0.imag(), + b0.imag(), + c0.t1, c1.t1); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, + a0.real() + a0.imag(), + b0.real() + b0.imag(), + c0.t2, c1.t2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp new file mode 100644 index 0000000000..08fe2b2810 --- /dev/null +++ b/include/cute/arch/mma_sm90.hpp @@ -0,0 +1,961 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM90_16x8x4_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[2]; + using BRegisters = double[1]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, + double const& b0, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), + "d"(b0), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM90_16x8x8_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[4]; + using BRegisters = double[2]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& b0, double const& b1, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(b0), "d"(b1), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM90_16x8x16_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[8]; + using BRegisters = double[4]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& a4, double const& a5, double const& a6, double const& a7, + double const& b0, double const& b1, double const& b2, double const& b3, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7, %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16, %17, %18, %19};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(a4), "d"(a5), "d"(a6), "d"(a7), + "d"(b0), "d"(b1), "d"(b2), "d"(b3), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM90_16x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[2]; + using BRegisters = complex[1]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& b0, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), + b0.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), + b0.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), + b0.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM90_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), + b0.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM90_16x8x8_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[4]; + using BRegisters = complex[2]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& b0, complex const& b1, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.real(), b1.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + b0.real(), b1.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + b0.imag(), b1.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM90_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.imag(), b1.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM90_16x8x16_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[8]; + using BRegisters = complex[4]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& a4, complex const& a5, + complex const& a6, complex const& a7, + complex const& b0, complex const& b1, + complex const& b2, complex const& b3, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + a4.imag(), a5.imag(), a6.imag(), a7.imag(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM90_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { +namespace GMMA { + +template< + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // FP16 accumulator + if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + // Dispatch against the Tile N mode size + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP32 accumulator + else if constexpr (std::is_same_v) { + + // FP16 inputs + if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // BF16 inputs + else if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // TF32 inputs + else if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x8_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (std::is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + // ElementA == int8_t && ElementB == int8_t + if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == int8_t && ElementB == uint8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == int8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == uint8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template< + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // FP16 accumulator + if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + // Dispatch against the Tile N mode size + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP32 accumulator + else if constexpr (std::is_same_v) { + static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + // FP16 inputs + if constexpr (std::is_same_v) { + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // BF16 inputs + else if constexpr (std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // TF32 inputs + else if constexpr (std::is_same_v) { + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (std::is_same_v) { + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + // ElementA == int8_t && ElementB == int8_t + if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == int8_t && ElementB == uint8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == int8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == uint8_t + else if constexpr (std::is_same_v && std::is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} +} // end namespace GMMA +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp new file mode 100644 index 0000000000..abac517044 --- /dev/null +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA Descriptor and utilities + +// GMMA enums and utilities +namespace GMMA +{ + +enum class LayoutType : uint8_t { + INTERLEAVE = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::INTERLEAVE: return "INTERLEAVE"; + case LayoutType::B128: return "B128"; + case LayoutType::B64: return "B64"; + case LayoutType::B32: return "B32"; + } + return nullptr; +} + +// Output operator for all enums in this namespace +CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { + char const* s = to_string(t); + if (s) { + std::operator<<(os, s); // Explicit call to avoid ambiguity + } else { + os.setstate(std::ios_base::failbit); + } + return os; +} + +} // end namespace GMMA + +union GmmaDescriptor +{ + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. + uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + }; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } + + // Printer + CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) + { + printf("GmmaDescriptor: 0x%016lx\n", t.desc_); + printf(" start_addr : 0x%04x\n", t.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast(t.layout_type_))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp new file mode 100644 index 0000000000..25a1d1714a --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -0,0 +1,12265 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90_ENABLED +#endif + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Warpgroup sync primitives + +CUTE_HOST_DEVICE +void +warpgroup_arrive() +{ +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); +#else + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif +} + +template +CUTE_HOST_DEVICE +void +warpgroup_wait() +{ + static_assert(N >= 0 && N <= 7, "_warpgroup.wait {N}; must be in range [0, 7]"); +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +#else + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif +} + +// Marks the commit point for one or more sized batch of warpgroup MMAs. +CUTE_HOST_DEVICE +void +warpgroup_commit_batch() +{ +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +#else + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(uint32_t& reg) { + asm volatile("" : "+r"(reg) :: "memory"); +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +namespace GMMA { + +enum class Major { + K = 0, + MN = 1 +}; + +enum class ScaleOut { + Zero = 0, + One = 1 +}; + +enum class ScaleIn { + Neg = -1, + One = 1 +}; + +} // namespace GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5, %6, %7, %8;\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8, %9, %10;\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7, %8, %9, %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10, %11, %12;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11, %12, %13, %14;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14, %15, %16;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19, %20, %21, %22;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22, %23, %24;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27, %28, %29, %30;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30, %31, %32;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35, %36, %37, %38;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38, %39, %40;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51, %52, %53, %54;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54, %55, %56;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67, %68, %69, %70;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70, %71, %72;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7, %8, %9, %10;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10, %11, %12;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11, %12, %13, %14;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14, %15, %16;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19, %20, %21, %22;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22, %23, %24;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35, %36, %37, %38;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38, %39, %40;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51, %52, %53, %54;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54, %55, %56;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67, %68, %69, %70;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70, %71, %72;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99, %100, %101, %102;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102, %103, %104;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131, %132, %133, %134;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134, %135, %136;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7, %8, %9, %10;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10, %11, %12;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11, %12, %13, %14;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14, %15, %16;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19, %20, %21, %22;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22, %23, %24;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35, %36, %37, %38;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38, %39, %40;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51, %52, %53, %54;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54, %55, %56;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67, %68, %69, %70;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70, %71, %72;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99, %100, %101, %102;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102, %103, %104;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131, %132, %133, %134;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134, %135, %136;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7, %8;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10, %11;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11, %12;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14, %15;\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19, %20;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22, %23;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35, %36;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38, %39;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51, %52;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54, %55;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67, %68;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70, %71;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99, %100;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102, %103;\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131, %132;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134, %135;\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=S8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*S8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x8x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x16x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13;\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x32x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x64x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x96x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x128x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x192x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101;\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 64x256x32 TN S32+=U8*U8 +template< + GMMA::ScaleOut scaleD = GMMA::ScaleOut::One +> +struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + { +#if defined(CUTE_ARCH_MMA_SM90_ENABLED) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133;\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "n"(int32_t(scaleD))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp new file mode 100644 index 0000000000..007781f56b --- /dev/null +++ b/include/cute/arch/util.hpp @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) + extern "C" { + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. + CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); + } +#endif + +namespace cute +{ + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_HOST_DEVICE +uint32_t +cast_smem_ptr_to_uint(void const* const ptr) +{ +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. + // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTE helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + + (void) ptr; + printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); + return 0; + +#endif +} + +// +// Utility for pointer interfaces +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return fn(s[Is]..., d[Id]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrS&& s, PtrD&& d) +{ + return detail::explode(fn, + s, make_int_sequence{}, + d, make_int_sequence{}); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) +{ + return detail::explode(fn, + a, make_int_sequence{}, + b, make_int_sequence{}, + c, make_int_sequence{}); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) +{ + return detail::explode(fn, + d, make_int_sequence{}, + a, make_int_sequence{}, + b, make_int_sequence{}, + c, make_int_sequence{}); +} + +} // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp new file mode 100644 index 0000000000..2c5d9c557a --- /dev/null +++ b/include/cute/atom/copy_atom.hpp @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace cute { + +// Generic copy_unpack for any Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + // Specializations can generalize on these checks + //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); + //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); + + using RegistersSrc = typename Operation::SRegisters; + using RegistersDst = typename Operation::DRegisters; + using RegTypeSrc = typename std::remove_extent::type; + using RegTypeDst = typename std::remove_extent::type; + constexpr int RegNumSrc = std::extent::value; + constexpr int RegNumDst = std::extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + detail::explode(Operation::copy, + rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + + +template +struct Copy_Atom; + +template +struct Copy_Atom : Copy_Atom, T> +{}; + +template +struct Copy_Atom, T> + : Copy_Traits +{ + using Traits = Copy_Traits; + + // Bit and Thr layouts from the Copy_Traits + using ThrID = typename Traits::ThrID; + using BitLayoutSrc = typename Traits::SrcLayout; + using BitLayoutDst = typename Traits::DstLayout; + using BitLayoutRef = typename Traits::RefLayout; + + using ValType = T; + + using ValLayoutSrc = decltype(upcast::value>(BitLayoutSrc{})); + using ValLayoutDst = decltype(upcast::value>(BitLayoutDst{})); + using ValLayoutRef = decltype(upcast::value>(BitLayoutRef{})); + + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + + static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); + static constexpr int NumValDst = size<1>(ValLayoutDst{}); + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(std::forward(args)...); + return Copy_Atom{traits}; + } + + // Print thread and data layouts for debugging + CUTE_HOST_DEVICE static + void + print_all() + { + print("ThrID: "); print(ThrID{}); print("\n"); + print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n"); + print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n"); + print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n"); + print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n"); + print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n"); + print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n"); + print("ValueType: %db", sizeof_bits::value); print("\n"); + } + + // + // Tensor call interfaces + // + + // Cast, check, and call + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor & dst) const + { + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || is_constant::value) { + // Dispatch to unpack for instruction + return copy_unpack(*this, src, dst); + } else { + // Recurse if needed by peeling the tensor mode + return copy(*this, tensor<0>(src), tensor<0>(dst)); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor && dst) const + { + return call(src, dst); + } +}; + +// +// A tiling of copy atoms +// + +template coord [Need not be 2D...] + class ShapeTile_MN> // coord space +struct TiledCopy : Copy_Atom +{ + // Layout information from the CopyAtom + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset + using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset + using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset + + using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); + using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + + // Layout information for the TiledCopy + using Tiler_MN = ShapeTile_MN; + using TiledShape_MN = decltype(shape(ShapeTile_MN{})); + using TiledLayout_TV = LayoutCopy_TV; + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Src. + // ThrX: The threads tiled across COPY_ATOMs Src. + // FrgV: The values local to a COPY_ATOM Src. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_S(STensor&& stensor) + { + return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Dst. + // ThrX: The threads tiled across COPY_ATOMs Dst. + // FrgV: The values local to a COPY_ATOM Dst. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_D(DTensor&& dtensor) + { + return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + } + + template + CUTE_HOST_DEVICE constexpr static + auto + thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + { + constexpr int R = remove_cvref_t::rank; + static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); + // Generalize the dimension checks for arbitrary rank + //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + + // Take the thrs/vals that the atom is interested in + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) + + // Transform to the trg layout + auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); + // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) + + // Transform the thrs mode from thrid to thr_idx + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) + + /// ================== + + // Tile the tensor for TiledLayout + auto t_tensor = zipped_divide(tensor, Tiler_MN{}); + // ((TileM,TileN,...),(RestM,RestN,...)) + + // Transform the tile mode + auto tv_tensor = t_tensor.compose(thrval2mn, _); + // ((thrid,val),(RM,RN,...)) + + // Unfold and return + return tv_tensor(make_coord(_,_), _); + } + + // retile_S and retile_D assume they are working with the reference layout -- they are the same + template + CUTE_HOST_DEVICE constexpr static + auto + retile(Tensor&& tensor) + { + constexpr int R = remove_cvref_t::rank; + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + + // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. + // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // and that shape is what we gather from the other modes of tensor + + auto V = size<0>(tensor); + + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{})); + // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV + + auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + // (atom_vals,rest_vals) -> (v,m,n) + + /// ======= + + // Tile the tensor for TileFrg + auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) + + // Transform the tile mode + auto v_tensor = t_tensor.compose(frg_layout_v, _); + // ((atom_vals,rest_vals),(1,RM,RN,...)) + + // Unfold and return + return v_tensor(_, append(Int<0>{},_)); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_MN() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(TiledShape_MN{}); + // (thr_idx,val_idx) -> (M,N) + auto layoutS_TV = tidfrg_S(ref_S); + // (M,K) -> (thr_idx,val_idx) + auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S)); + + // athrid = (v,m,k) -> thr_idx + auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutS_MK, thrID_S); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_TV() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(TiledShape_MN{}); + // (thr_idx,val_idx) -> (M,N) + return tidfrg_S(ref_S)(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_MN() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(TiledShape_MN{}); + // (thr_idx,val_idx) -> (M,N) + auto layoutD_TV = tidfrg_D(ref_D); + // (M,K) -> (thr_idx,val_idx) + auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D)); + + // athrid = (v,m,k) -> thr_idx + auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutD_MK, thrID_D); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_TV() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(TiledShape_MN{}); + // (thr_idx,val_idx) -> (M,N) + return tidfrg_D(ref_D)(_,_,Int<0>{}); + } + + template + struct ThrCopy : Copy_Atom + { + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(std::forward(stensor).data(), tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(std::forward(dtensor).data(), tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), + "Expected ValType for tiling SrcTensor."); + return make_tensor(std::forward(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), + "Expected ValType for tiling DstTensor."); + return make_tensor(std::forward(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } + }; + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_slice(ThrIdx const& thr_idx) + { + return ThrCopy(thr_idx); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } +}; + + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_impl(Copy_Atom const& atom, + LayoutCopy_TV const&, + Tile const&) +{ + return TiledCopy, LayoutCopy_TV, Tile>{atom}; +} + +// +// These tile the Copy_Atom as a whole +// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) +{ + using MNK = typename TiledMMA::TiledShape_MNK; + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{}))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) +{ + using MNK = typename TiledMMA::TiledShape_MNK; + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{}))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) +{ + using MNK = typename TiledMMA::TiledShape_MNK; + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{}))); +} + +template > +CUTE_HOST_DEVICE +auto +make_tiled_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto thr_layout_mn = append(thr_layout, Layout<_1>{}); + auto val_layout_mn = append(val_layout, Layout<_1>{}); + + // Take the raked_products to compute the Layout_MN + auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + + //print("thr_layout: "); print(thr_layout_mn); print("\n"); + //print("val_layout: "); print(val_layout_mn); print("\n"); + //print("layout_mn : "); print(layout_mn); print("\n"); + //print("layout_tv : "); print(layout_tv); print("\n"); + + return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); +} + +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_S(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +} + +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_D(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +} + +// +// Size +// + +// The logical size of a TileCopy +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledCopy const&) +{ + return size(typename TiledCopy::TiledShape_MN{}); +} + +// The number of threads involved in a TiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledCopy const&) +{ + return typename TiledCopy::TiledNumThr{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +auto +print_latex(TiledCopy const& copy) +{ + auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); + auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); + + print_latex_copy(layoutS_MN, thrID_S, + layoutD_MN, thrID_D); +} + +// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); + + assert(size<0>(S) == size<0>(D)); + assert(size<1>(S) == size<1>(D)); + + char const* latex_header = + "\\documentclass{standalone}\n" + "\\usepackage{tikz}\n" + "\\usetikzlibrary{external}\n" + "\\tikzexternalize\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}",}; + + // Header + printf("%% LayoutS: "); print(S); printf("\n"); + printf("%% ThrIDS : "); print(TS); printf("\n"); + printf("%% LayoutD: "); print(D); printf("\n"); + printf("%% ThrIDD : "); print(TD); printf("\n\n"); + + printf(latex_header); + + // S starting at 0,0 + for (int i = 0; i < size<0>(S); ++i) { + for (int j = 0; j < size<1>(S); ++j) { + int thrid = S(i,j) % size(TS); + int val_idx = S(i,j) / size(TS); + int thr_idx = TS(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j, + thr_idx, val_idx); + } + } + + // D starting at 0,size<1>(S)+3 + for (int i = 0; i < size<0>(D); ++i) { + for (int j = 0; j < size<1>(D); ++j) { + int thrid = D(i,j) % size(TD); + int val_idx = D(i,j) / size(TD); + int thr_idx = TD(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j + size<1>(S) + 3, + thr_idx, val_idx); + } + } + + // S Labels + for (int i = 0, j = -1; i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(S); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + // D Labels + for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); + } + for (int j = 0, i = -1; j < size<1>(D); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); + } + + // Footer + printf(latex_footer); +} + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +// Config +#if (__CUDACC_VER_MAJOR__ >= 12) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +#include +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp new file mode 100644 index 0000000000..83cb05652a --- /dev/null +++ b/include/cute/atom/copy_traits.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits +{ + static_assert(sizeof(CopyOperation) == 0, "Copy_Traits not implemented for this Copy_Operation."); +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm75.hpp b/include/cute/atom/copy_traits_sm75.hpp new file mode 100644 index 0000000000..13eb166e29 --- /dev/null +++ b/include/cute/atom/copy_traits_sm75.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2>>, + Stride,Stride< _1,_128>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _2>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _4>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm80.hpp b/include/cute/atom/copy_traits_sm80.hpp new file mode 100644 index 0000000000..089d19347f --- /dev/null +++ b/include/cute/atom/copy_traits_sm80.hpp @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Element copy selector +template +CUTE_HOST_DEVICE constexpr +auto +select_elementwise_copy(SrcTensor const&, DstTensor const&) +{ + using SrcType = typename SrcTensor::value_type; + using DstType = typename DstTensor::value_type; + +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType) && + (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) + { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; +#endif +} + +} diff --git a/include/cute/atom/copy_traits_sm90.hpp b/include/cute/atom/copy_traits_sm90.hpp new file mode 100644 index 0000000000..8c5e843f4e --- /dev/null +++ b/include/cute/atom/copy_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp new file mode 100644 index 0000000000..18e22bf604 --- /dev/null +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -0,0 +1,795 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor const& tma_desc_; + uint64_t& tma_load_mbar_; + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const dst_ptr, + Coord const& src_coord, seq) const + { +#if 0 + print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z); + print(" TMA Coord "); print(src_coord); print("\n"); + print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), + uint64_t(tma_desc_.size1_), + uint64_t(tma_desc_.size2_), + uint64_t(tma_desc_.size3_))); print("\n"); +#endif + + SM90_TMA_LOAD::copy(&tma_desc_, + tma_load_mbar_, + dst_ptr, + get(src_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a gmem tensor with TmaCoordIterator .data() + // Dst needs to be a smem tensor + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor + static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD"); + + traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); + } +}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + GmemStrides g_stride_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + // assert(multicast_mask == 0); + (void) multicast_mask; + return {tma_desc_, tma_mbar}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; + return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), + g_shape, + g_stride_); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor const& tma_desc_; + uint64_t& tma_load_mbar_; + uint16_t const& multicast_mask_; + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const dst_ptr, + Coord const& src_coord, seq) const + { +#if 0 + print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z); + print(" TMA Coord "); print(src_coord); print("\n"); + print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), + uint64_t(tma_desc_.size1_), + uint64_t(tma_desc_.size2_), + uint64_t(tma_desc_.size3_))); print("\n"); +#endif + + SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, + tma_load_mbar_, + multicast_mask_, + dst_ptr, + get(src_coord)...); + } + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor + static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST"); + + traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); + } +}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + GmemStrides g_stride_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {tma_desc_, tma_load_mbar, multicast_mask}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; + return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), + g_shape, + g_stride_); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor tma_desc_; + GmemStrides g_stride_; + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; + return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), + g_shape, + g_stride_); + } + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const src_ptr, + Coord const& dst_coord, seq) const + { +#if 0 + print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z); + print(" TMA Coord "); print(dst_coord); print("\n"); + print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), + uint64_t(tma_desc_.size1_), + uint64_t(tma_desc_.size2_), + uint64_t(tma_desc_.size3_))); print("\n"); +#endif + + SM90_TMA_STORE::copy(&tma_desc_, + src_ptr, + get(dst_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq{}); + } +}; + +// +// MAKE_TMA_COPY and related +// + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(ComposedLayout,Offset,SLayout>) +{ + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle."); + case 3: return TMA::SmemSwizzleBits::B128; + case 2: return TMA::SmemSwizzleBits::B64; + case 1: return TMA::SmemSwizzleBits::B32; + case 0: return TMA::SmemSwizzleBits::DISABLE; + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout) +{ + return TMA::SmemSwizzleBits::DISABLE; +} + +template +auto +get_nonswizzle_layout(ComposedLayout,Offset,SLayout> const& slayout) +{ + return slayout.layout_fn(); +} + +template +auto +get_nonswizzle_layout(Layout const& slayout) +{ + return slayout; +} + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +template +CUTE_HOST +auto +make_tma_copy(CopyOp, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tile const& cta_tile, + Cluster_Size const& cluster_size) +{ + static_assert((std::is_same::value && is_constant<1, Cluster_Size>::value) || + (std::is_same::value) || + (std::is_same::value && is_constant<1, Cluster_Size>::value)); + + using T = typename Tensor::value_type; + + // + // TMA parameter checking + // + + auto flat_glayout = flatten(gtensor.layout()); + + CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{}, + "CTA_Tile cannot have more than five modes, TMA arch restriction."); + CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{}, + "If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode."); + CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)), + "CTA_Tile must be compatible with SLayout."); + CUTE_STATIC_ASSERT_V(is_integral{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{}, + "Expecting a pow2 integral Cluster_Size leq 16."); + CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{}, + "ClusterShape must divide domain size of slayout."); + + // + // TMA slayout manipulation + // + + auto tma_multimode = rank(flat_glayout) > Int<5>{}; + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout)); + + // Truncate any incompatibilities + auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){ + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA."); + constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5)); + + // Keep only the static-1 basis modes into gmem + auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode); + // Keep only the portion each multicast CTA will be responsible for + auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size)); + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto flat_gbasis = make_basis_like(shape(flat_glayout)); + + // Fold the flat_gbasis into the glayout + auto glayout_basis = make_layout(shape(gtensor), + stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), + make_layout(repeat_like(shape(gtensor), Int<2>{}))))); + + // Tile the modes of gtensor with cta_tile + auto cta_glayout_basis = composition(glayout_basis, cta_tile); + + // Check that the cta_tile selects modes from gtensor properly + for_each(flatten(stride(cta_glayout_basis)), [](auto d) { + static_assert(is_constant<1, decltype(d.value())>::value, + "CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout."); + }); + + // Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout + auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc)); + + // Append any missing basis on the end as size-1 modes b/c they got truncated + auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){ + auto k = find(init, e); + return remove(init); + }); + + // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode + auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc, + make_layout(repeat(Int<1>{}), missing_basis))); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_tile : "); print(cta_tile); print("\n"); + print("cluster_size : "); print(cluster_size); print("\n"); + print("flat_gbasis : "); print(flat_gbasis); print("\n"); + print("cta_glayout : "); print(cta_glayout_basis); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); + print("missing_b : "); print(missing_basis); print("\n"); + print("tma_layout_cta: "); print(tma_layout_cta); print("\n"); +#endif + + // + // TMA gmem desc info + // + + constexpr int TmaRANK = cute::min(rank(flat_glayout), 5); + void* gmem_address = (void*) gtensor.data(); + + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + // NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below + auto e = stride(tma_layout_cta); + constexpr int j = decltype(e.mode())::value; + constexpr int tma_i = i < 5 ? i : 4; + + // Problem stride + uint64_t stride_j = stride(flat_glayout) * sizeof(T); + uint64_t old_stride = gmem_prob_stride[tma_i]; + gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j); + + // Problem shape + uint64_t shape_j = shape(flat_glayout); + if (gmem_prob_stride[tma_i] != 0) { + // We're "resetting" this TMA mode and using it as a "multimode" + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[tma_i]) + + 1; + } else { + gmem_prob_shape[tma_i] = shape_j; + } + }); + + assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned + + assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 + + assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1 + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + + // + // TMA smem desc info + // + + // TMA smem box size + cute::array smem_box_shape = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + uint32_t shape_i = shape(tma_layout_cta); + constexpr int tma_i = i < 5 ? i : 4; + if (tma_multimode && tma_i == 4) { + // We're "reusing" this TMA mode and using it as a "multimode" + smem_box_shape[tma_i] = 1; + } else { + smem_box_shape[tma_i] = shape_i; + } + }); + + // TMA smem mode strides + [[maybe_unused]] cute::array smem_box_stride = {1,1,1,1,1}; + + assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + + assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 + + // + // Construct the descriptor + // + + TmaDescriptor tma_desc = {0}; + +#if (__CUDACC_VER_MAJOR__ >= 12) + + // + // TMA general info + // + + cuuint32_t tma_dim = TmaRANK; + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout)); + + CUresult result = cuTensorMapEncodeTiled( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to intialize the TMA descriptor " << result << std::endl; + assert(false); + } +#endif // (__CUDACC_VER_MAJOR__ >= 12) + + // + // Construct the Copy_Traits + // + + // Finally, get the inverse permutation of the E bases for the mocked gmem stride + auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { + auto k = find(stride(tma_layout_cta), E{}); + // NOTE: gcc 7.3.5 WAR -- avoid if constexpr + int32_t tma_coord_stride = int32_t(stride(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16)); + return conditional_return(tma_multimode && (k >= Int<4>{}), + E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride + E{}); + }); + + // Give that the profile of gtensor and fold it + auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat), + make_layout(repeat_like(shape(gtensor), Int<2>{})))); + + constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8; + using Traits = Copy_Traits, decltype(gmem_stride_bases)>; + +#if 0 + print("num_bits : "); print(num_bits); print("\n"); + print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); +#endif + + // + // Construct the TiledCopy + // + + // The ThrVal layout for 1 TMA instruction within cta_tile + auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{})); + // The ThrVal layout for N TMA instructions within cta_tile + auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size)); + +#if 0 + print("layout_tv : "); print(layout_tv); print("\n"); +#endif + + return TiledCopy, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases}; +} + +// Explicit defaulting +template +CUTE_HOST +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); +} + +template +CUTE_HOST +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Cluster_Size const& cluster_size) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); +} + +} // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp new file mode 100644 index 0000000000..c3025f5065 --- /dev/null +++ b/include/cute/atom/mma_atom.hpp @@ -0,0 +1,1081 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include +#include +#include + +namespace cute { + +// Generic mma_unpack for any MMA_Traits +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename std::remove_extent::type; + using RegTypeA = typename std::remove_extent::type; + using RegTypeB = typename std::remove_extent::type; + using RegTypeC = typename std::remove_extent::type; + constexpr int RegNumD = std::extent::value; + constexpr int RegNumA = std::extent::value; + constexpr int RegNumB = std::extent::value; + constexpr int RegNumC = std::extent::value; + + Tensor rA = recast(A); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + if constexpr (std::is_same::value) + { + static_assert(std::is_same::value, "GMMA C and D value_type must match."); + static_assert(std::is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(Operation::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } else + { + Tensor rD = recast(D); + Tensor rC = recast(C); + + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(Operation::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } +} + + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; +template +struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; +template +struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; +template +struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; + +} // end namespace detail + +template +struct MMA_Atom; + +template +struct MMA_Atom : MMA_Atom> +{}; + +template +struct MMA_Atom> + : MMA_Traits +{ + using Traits = MMA_Traits; + + // Element value types from the MMA_Traits + using ValTypeD = typename Traits::ElementDVal; + using ValTypeA = typename Traits::ElementAVal; + using ValTypeB = typename Traits::ElementBVal; + using ValTypeC = typename Traits::ElementCVal; + + // Thr-Val layouts from the MMA_Traits + using Shape_MNK = typename Traits::Shape_MNK; + using ThrID = typename Traits::ThrID; + using LayoutC_TV = typename Traits::CLayout; + using LayoutA_TV = typename Traits::ALayout; + using LayoutB_TV = typename Traits::BLayout; + + // Fragment value types from the MMA_Traits (optional, defaults to Val type) + using FrgTypeD = typename detail::FrgTypeC_or_Default::type; + using FrgTypeA = typename detail::FrgTypeA_or_Default::type; + using FrgTypeB = typename detail::FrgTypeB_or_Default::type; + using FrgTypeC = typename detail::FrgTypeC_or_Default::type; + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(std::forward(args)...); + return MMA_Atom{traits}; + } + + // Print thread and data layouts for debugging + CUTE_HOST_DEVICE static + void + print_all() + { + print("ThrID: "); print(ThrID{}); print("\n"); + print("LayoutA_TV: "); print(LayoutA_TV{}); print("\n"); + print("LayoutB_TV: "); print(LayoutB_TV{}); print("\n"); + print("LayoutC_TV: "); print(LayoutC_TV{}); print("\n"); + } + + // + // Tensor call interfaces + // + + // Cast, check, and call fma + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) const + { + static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); + static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); + static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); + static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); + + return mma_unpack(*this, D, A, B, C); + } + + // Three arguments reproduces C + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor const& A, + Tensor const& B, + Tensor & C) const + { + return call(C, A, B, C); + } + + // + // make_fragment_A|B|C + // These functions are awkward as they expect already-partitioned tensors + // resulting from a previous call to partition_A|B|C + // The reasoning is that we can inspect the layout of the partitioned data + // and attempt to match it in generated fragment to promote vectorization + // when copying from partition to fragment. + // + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_C(CTensor&& ctensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN + CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); + + // C is a bit special because we are after accumulators here + // The input/output type doesn't have to match the accumulator type + //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); + + // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor + return make_tensor(shape(ctensor)); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_A(ATensor&& atensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK + CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); + static_assert(std::is_same::value_type>::value, "Expecting ValTypeA type"); + + if constexpr (has_dereference::value) { + return recast(std::forward(atensor)); + } else { + return make_tensor(make_fragment_like(atensor.layout())); + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_B(BTensor&& btensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK + CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); + static_assert(std::is_same::value_type>::value, "Expecting ValTypeB type"); + + if constexpr (has_dereference::value) { + return recast(std::forward(btensor)); + } else { + return make_tensor(make_fragment_like(btensor.layout())); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// +// A tiling of mma atoms +// + +template +struct ThrMMA; + +template >, + class ValLayoutMNK = Layout>, + class PermutationsMNK = Tile> +struct TiledMMA : MMA_Atom +{ + static_assert(rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert(rank_v == 3, "TiledMMA requires rank-3 ValLayoutMNK"); + static_assert(rank_v == 3, "TiledMMA requires rank-3 PermutationsMNK"); + + using AtomShape_MNK = typename MMA_Atom::Shape_MNK; + + using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; + using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; + using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; + + // ThrV -> thread_idx + using AtomThrID = typename MMA_Atom::ThrID; + + // (M,N,K) + using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}), + size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}), + size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{}))); + + // thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx + using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + using TidLayout = decltype(right_inverse(ThrLayoutVMNK{})); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + thrfrg_C(CTensor&& ctensor) + { + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); + CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), + left_inverse(get<1>(PermutationsMNK{}))); + auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<1>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + + // Tile the tensor for the C-threads + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(ThrLayoutVMNK{})), + make_layout(size<2>(ThrLayoutVMNK{})))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) + + return thr_tensor; + } + + // Tile from (M,N,...) + // to (thr_idx,(FrgV,(RestM,RestN,...))) + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_C(CTensor&& ctensor) + { + // Don't need a ctile composition because ThrK is last mode in TidLayout + + return thrfrg_C(ctensor).compose(TidLayout{}, _); + } + + // Tile a tensor or a layout from shape + // (M,K,...) + // to shape + // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr static + auto + thrfrg_A(ATensor&& atensor) + { + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); + CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), + left_inverse(get<2>(PermutationsMNK{}))); + auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(ThrLayoutVMNK{})), + make_layout(size<3>(ThrLayoutVMNK{})))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + // Tile from (M,K,...) + // to (thr_idx,(FrgV,(RestM,RestK,...))) + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_A(ATensor&& atensor) + { + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), + make_stride( Int<1>{} , Int<0>{} )), + _)); + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + + return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _); + } + + // Tile a tensor or a layout from shape + // (N,K,...) + // to shape + // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestN: The values tiled in N. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr static + auto + thrfrg_B(BTensor&& btensor) + { + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); + CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})), + left_inverse(get<2>(PermutationsMNK{}))); + auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(ThrLayoutVMNK{})), + make_layout(size<3>(ThrLayoutVMNK{})))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + + return thr_tensor; + } + + // Tile from (N,K,...) + // to (thr_idx,(FrgV,(RestN,RestK,...))) + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_B(BTensor&& btensor) + { + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), + make_stride( Int<0>{} , Int<1>{} )), + _)); + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + + return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _); + } + + template ::value)> + CUTE_HOST_DEVICE static constexpr + auto + get_slice(ThrIdx const& thr_idx) + { + auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx); + return ThrMMA(thr_vmnk); + } + + template ::value)> + CUTE_HOST_DEVICE static constexpr + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } + + // + // Utility for printing and visualization + // + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutC_MN() + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); + // (M,N) -> (cthrid,frg) + auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); + + // cthrid = (v,m,n) -> thr_idx + auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{}); + + return cute::make_tuple(layoutC_MN, thrID_C); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutC_TV() + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); + + return tidfrg_C(ref_C); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutA_MK() + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); + // (M,K) -> (athrid,frg) + auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); + + // athrid = (v,m,k) -> thr_idx + auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_); + + return cute::make_tuple(layoutA_MK, thrID_A); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutA_TV() + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + + return tidfrg_A(ref_A); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutB_NK() + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); + // (N,K) -> (bthrid,frg) + auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); + + // bthrid = (v,n,k) -> thr_idx + auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_); + + return cute::make_tuple(layoutB_NK, thrID_B); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutB_TV() + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + + return tidfrg_B(ref_B); + } +}; + +template +struct ThrMMA : TiledMMA +{ + // Use ThrVMNK and thrfrg rather than thr_idx and tidfrg + // to support swizzled threads partitioning dynamic layouts + ThrVMNK thr_vmnk_; + + CUTE_HOST_DEVICE constexpr + ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {} + + template + CUTE_HOST_DEVICE constexpr + auto + partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(std::forward(ctensor).data(), thrfrg_C(ctensor.layout())); + + auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(std::forward(atensor).data(), thrfrg_A(atensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(std::forward(btensor).data(), thrfrg_B(btensor.layout())); + + auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_C(CTensor&& ctensor) const + { + return make_fragment_C(partition_C(ctensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_A(ATensor&& atensor) const + { + return make_fragment_A(partition_A(atensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_B(BTensor&& btensor) const + { + return make_fragment_B(partition_B(btensor)); + } +}; + +// +// These tile the MMA_Atom as a whole +// + +template >, + class MMAValLayout = Layout>, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Atom const&, + MMAThrLayout const& thr_layout = {}, + MMAValLayout const& val_layout = {}, + Permutations const& permutations = {}) +{ + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1>{}); + auto val_layout_mnk = append<3>(val_layout, Layout<_1>{}); + auto permutation_mnk = append<3>(permutations, _); + + return TiledMMA, + decltype(thr_layout_mnk), + decltype(val_layout_mnk), + decltype(permutation_mnk)>{}; +} + +template >, + class MMAValLayout = Layout>, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Op const&, + MMAThrLayout const& thr_layout = {}, + MMAValLayout const& val_layout = {}, + Permutations const& permutations = {}) +{ + // Attempt to wrap in an MMA_Atom<> and forward + return make_tiled_mma(MMA_Atom{}, thr_layout, val_layout, permutations); +} + +// +// partition_fragment_C -- static context +// + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA, Shape_MN shapeMN) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + + auto V = size<1>(typename TiledMMA::AtomLayoutC_TV{}); + auto M = shape_div(size<0>(shapeMN), size<0>(atomMNK) * size<1>(thrVMNK)); + auto N = shape_div(size<1>(shapeMN), size<1>(atomMNK) * size<2>(thrVMNK)); + auto frg_shape = tuple_cat(make_shape(V,M,N), take<2,R>(shapeMN)); + + return make_tensor::FrgTypeC>(frg_shape); +} + +// partition_fragment_A and partition_fragment_B often depend on the +// layout of A and B and/or the thread_idx that is requesting the partition. +// For these reasons, they should not be used in a static context. +// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. + +// +// Size +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledMMA const& mma) +{ + return size(typename TiledMMA::TiledShape_MNK{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledMMA const& mma) +{ + return size(typename TiledMMA::ThrLayoutVMNK{}); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +auto +print_latex(TiledMMA const& mma) +{ + auto layout_and_thrid_C = mma.get_layoutC_MN(); + auto layoutC_MN = get<0>(layout_and_thrid_C); + auto thrID_C = get<1>(layout_and_thrid_C); + + auto layout_and_thrid_A = mma.get_layoutA_MK(); + auto layoutA_MK = get<0>(layout_and_thrid_A); + auto thrID_A = get<1>(layout_and_thrid_A); + + auto layout_and_thrid_B = mma.get_layoutB_NK(); + auto layoutB_NK = get<0>(layout_and_thrid_B); + auto thrID_B = get<1>(layout_and_thrid_B); + + print_latex_mma(layoutC_MN, thrID_C, + layoutA_MK, thrID_A, + layoutB_NK, thrID_B); +} + +// EXPERIMENTAL -- Doesn't work with Swizzled Thr TileMMAs... +template +CUTE_HOST_DEVICE +auto +print_latex_2(TiledMMA const& mma) +{ + print_latex_mma(typename TiledMMA::TiledShape_MNK{}, + mma.get_layoutC_TV(), + mma.get_layoutA_TV(), + mma.get_layoutB_TV()); +} + +// MNK MMA Layout to console printer -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + int a_width = size<1>(A) * 6 + 4; + + // Print out B (white-shifted) k-by-n + for (int k = 0; k < size<1>(B); ++k) { + // Header + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n"); + // Values + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); + printf("|\n"); + } + // Footer + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n\n"); + + // Print out A m-by-k and C m-by-n + for (int m = 0; m < size<0>(A); ++m) { + // Header + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); + // Values + for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); + printf("| "); + for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); + printf("|\n"); + } + // Footer + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); +} + +// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + char const* latex_header = + "\\documentclass{standalone}\n" + "\\usepackage{tikz}\n" + "\\usetikzlibrary{external}\n" + "\\tikzexternalize\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + + // Header + printf("%% LayoutC: "); print(C); printf("\n"); + printf("%% ThrIDC : "); print(TC); printf("\n"); + printf("%% LayoutA: "); print(A); printf("\n"); + printf("%% ThrIDA : "); print(TA); printf("\n"); + printf("%% LayoutB: "); print(B); printf("\n"); + printf("%% ThrIDB : "); print(TB); printf("\n\n"); + + printf(latex_header); + + // C starting at 0,0 + for (int m = 0; m < size<0>(C); ++m) { + for (int n = 0; n < size<1>(C); ++n) { + int thrid = C(m,n) % size(TC); + int val_idx = C(m,n) / size(TC); + int thr_idx = TC(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + m, n, + thr_idx, val_idx); + } + } + + // A starting at 0,-size<1>(A)-1 + for (int m = 0; m < size<0>(A); ++m) { + for (int k = 0; k < size<1>(A); ++k) { + int thrid = A(m,k) % size(TA); + int val_idx = A(m,k) / size(TA); + int thr_idx = TA(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + m, k-1-size<1>(A), + thr_idx, val_idx); + } + } + + // B starting at -size<1>(B)-1,0 + for (int n = 0; n < size<0>(B); ++n) { + for (int k = 0; k < size<1>(B); ++k) { + int thrid = B(n,k) % size(TB); + int val_idx = B(n,k) / size(TB); + int thr_idx = TB(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + k-1-size<1>(B), n, + thr_idx, val_idx); + } + } + + // A labels + for (int m = 0, k = -1; m < size<0>(A); ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); + } + for (int k = 0, m = -1; k < size<1>(A); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); + } + // B labels + for (int n = 0, k = -1; n < size<0>(B); ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); + } + for (int k = 0, n = -1; k < size<1>(B); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); + } + + // Footer + printf(latex_footer); +} + +// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex_mma(Shape_MNK const& shape_mnk, + LayoutC const& C, // (thr_idx,vid) -> (m,n) + LayoutA const& A, // (thr_idx,vid) -> (m,k) + LayoutB const& B) // (thr_idx,vid) -> (n,k) +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + char const* latex_header = + "\\documentclass{standalone}\n" + "\\usepackage{tikz}\n" + "\\usetikzlibrary{external}\n" + "\\tikzexternalize\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + + // Header + printf("%% Shape_MNK: "); print(shape_mnk); printf("\n"); + printf("%% LayoutC : "); print(C); printf("\n"); + printf("%% LayoutA : "); print(A); printf("\n"); + printf("%% LayoutB : "); print(B); printf("\n\n"); + + printf(latex_header); + + int M = size<0>(shape_mnk); + int N = size<1>(shape_mnk); + int K = size<2>(shape_mnk); + + // C starting at 0,0 + bool c_filled[M][N] = {}; + for (int t = 0; t < size<0>(C); ++t) { + for (int v = 0; v < size<1>(C); ++v) { + int m = C(t,v) % M; + int n = C(t,v) / M; + + if (not c_filled[m][n]) { + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[t % 8], + m, n, + t, v); + c_filled[m][n] = true; + } + } + } + + // A starting at 0,-size<1>(A)-1 + bool a_filled[M][K] = {}; + for (int t = 0; t < size<0>(A); ++t) { + for (int v = 0; v < size<1>(A); ++v) { + int m = A(t,v) % M; + int k = A(t,v) / M; + + if (not a_filled[m][k]) { + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[t % 8], + m, k - 1 - K, + t, v); + a_filled[m][k] = true; + } + } + } + + // B starting at -size<1>(B)-1,0 + bool b_filled[N][K] = {}; + for (int t = 0; t < size<0>(B); ++t) { + for (int v = 0; v < size<1>(B); ++v) { + int n = B(t,v) % N; + int k = B(t,v) / N; + + if (not b_filled[n][k]) { + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[t % 8], + k - 1 - K, n, + t, v); + b_filled[n][k] = true; + } + } + } + + // A labels + for (int m = 0, k = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m); + } + for (int k = 0, m = -1; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k); + } + // B labels + for (int n = 0, k = -1; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n); + } + for (int k = 0, n = -1; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k); + } + + // Footer + printf(latex_footer); +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp new file mode 100644 index 0000000000..a8c3323a36 --- /dev/null +++ b/include/cute/atom/mma_traits.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +template +struct MMA_Traits +{ + static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); +}; + +template +struct MMA_Traits> +{ + using ElementDVal = D; + using ElementAVal = A; + using ElementBVal = B; + using ElementCVal = C; + + // Logical shape of the MMA + using Shape_MNK = Shape<_1,_1,_1>; + + // Logical thread id (tid) -> tidx + using ThrID = Layout<_1>; + + // (Logical thread id (tid), Logical value id (vid)) -> coord + + // (tid,vid) -> (m,k) + using ALayout = Layout>; + // (tid,vid) -> (n,k) + using BLayout = Layout>; + // (tid,vid) -> (m,n) + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm61.hpp b/include/cute/atom/mma_traits_sm61.hpp new file mode 100644 index 0000000000..85d4e98787 --- /dev/null +++ b/include/cute/atom/mma_traits_sm61.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_1,_1,_4>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int16_t; + using ElementBVal = int16_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_1,_1,_2>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm70.hpp b/include/cute/atom/mma_traits_sm70.hpp new file mode 100644 index 0000000000..79430350ce --- /dev/null +++ b/include/cute/atom/mma_traits_sm70.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +namespace { + +// Logical thread id to thread idx (quadpair) +using SM70_QuadPair = Layout, + Stride<_1,_16>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Row = Layout, + Stride<_1,_8>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Col = Layout,_4>, + Stride,_1>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_16b = Layout, + Stride<_1,_8>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, + Stride,Stride<_8,_2,_32>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm75.hpp b/include/cute/atom/mma_traits_sm75.hpp new file mode 100644 index 0000000000..405e871fd2 --- /dev/null +++ b/include/cute/atom/mma_traits_sm75.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_1>>>; + using BLayout = Layout,_2>, + Stride,_8>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_1>>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,_4>, + Stride,_8>>; + using BLayout = Layout,_4>, + Stride,_8>>; + using CLayout = Layout,_2>, + Stride,_8>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp new file mode 100644 index 0000000000..6636b7aaa5 --- /dev/null +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -0,0 +1,446 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +#include + +#include + +namespace cute +{ + +namespace { + +// (T32,V1) -> (M8,N8) +using SM80_8x4 = Layout,_1>, + Stride,_0>>; +// (T32,V2) -> (M8,N8) +using SM80_8x8_Row = Layout,_2>, + Stride,_8>>; +// (T32,V4) -> (M8,N16) +using SM80_8x16_Row = Layout,_4>, + Stride,_8>>; +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = SM80_16x8_Row; + using BLayout = SM80_8x8_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_16,_8,_128>>>; + using BLayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = cutlass::tfloat32_t; + using ElementBVal = cutlass::tfloat32_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = SM80_8x4; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = cutlass::tfloat32_t; + using ElementBVal = cutlass::tfloat32_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = double; + using ElementAVal = double; + using ElementBVal = double; + using ElementCVal = double; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x4; + using BLayout = SM80_8x4; + using CLayout = SM80_8x8_Row; +}; + +// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = complex; + using ElementAVal = complex; + using ElementBVal = complex; + using ElementCVal = complex; +}; + +// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ElementAVal = complex; + using ElementBVal = complex; + using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x16_Row; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout, Shape <_4, _2>>, + Stride, Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = int32_t; + using ElementAVal = cute::uint1b_t; + using ElementBVal = cute::uint1b_t; + using ElementCVal = int32_t; + + using Shape_MNK = Shape<_16,_8,_256>; + using ThrID = Layout<_32>; + using ALayout = Layout>, + Stride<_64,Stride<_64,_16,_8,_2048>>>; + using BLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + using CLayout = SM80_16x8_Row; +}; +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90.hpp b/include/cute/atom/mma_traits_sm90.hpp new file mode 100644 index 0000000000..b7a12b98f4 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ElementDVal = double; + using ElementAVal = double; + using ElementBVal = double; + using ElementCVal = double; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = Layout,_1>, + Stride,_0>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +template <> +struct MMA_Traits +{ + using ElementDVal = double; + using ElementAVal = double; + using ElementBVal = double; + using ElementCVal = double; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +template <> +struct MMA_Traits +{ + using ElementDVal = double; + using ElementAVal = double; + using ElementBVal = double; + using ElementCVal = double; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _4>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _4>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////////// +//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = complex; + using ElementAVal = complex; + using ElementBVal = complex; + using ElementCVal = complex; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = complex; + using ElementAVal = complex; + using ElementBVal = complex; + using ElementCVal = complex; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ElementDVal = complex; + using ElementAVal = complex; + using ElementBVal = complex; + using ElementCVal = complex; +}; + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp new file mode 100644 index 0000000000..d390dafc58 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -0,0 +1,2975 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +namespace GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major GMMA layouts in units of bits +using Layout_MN_INTER_Atom_Bits = Layout,Stride<_1,_128>>; +using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; +using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; +using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; + +// K-major GMMA layouts in units of bits +using Layout_K_INTER_Atom_Bits = Layout,Stride<_128,_1>>; +using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; +using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; +using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; + +// M|N-major layouts in units of Type +template +using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); +template +using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); +template +using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); +template +using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); + +// K-major layouts in units of Type +template +using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); +template +using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); +template +using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); +template +using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); + +// With GMMA::Major param +template +using Layout_INTER_Atom = typename std::conditional, + Layout_K_INTER_Atom>::type; +template +using Layout_SW32_Atom = typename std::conditional, + Layout_K_SW32_Atom>::type; +template +using Layout_SW64_Atom = typename std::conditional, + Layout_K_SW64_Atom>::type; +template +using Layout_SW128_Atom = typename std::conditional, + Layout_K_SW128_Atom>::type; + +// Helper for GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +smem_selector() +{ + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +// +// Tensor to LayoutType utility +// + +// smem_ptr_swizzle LayoutType +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor>>, + Layout> const&) +{ + static_assert(M == 4, "Unsupported layout swizzle"); + static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + case 0: return LayoutType::INTERLEAVE; + case 1: return LayoutType::B32; + case 2: return LayoutType::B64; + case 3: return LayoutType::B128; + } + return LayoutType::INTERLEAVE; // ERROR +} + +// smem_ptr non-swizzled LayoutType +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor>, + Layout> const&) +{ + return LayoutType::INTERLEAVE; +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for GMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_gmma_desc // +* /////////////////////////////// +* Each GmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to GMMA shape +* k : integer in [1,32] corresponding to GMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_gmma_desc // +* ////////////////////////////// +* Each GmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +make_gmma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + GmmaDescriptor desc; + + // Layout type + constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); + desc.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(u128_tensor.data().get()); + desc.start_address_ = start_address >> 4; + + constexpr uint8_t base_offset = 0; + desc.base_offset_ = base_offset; + + // LayoutType meta + constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : + LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : + LAYOUT_TYPE == GMMA::LayoutType::B128 ? 8 : -1; + + if constexpr (MajorMode == GMMA::Major::MN) + { + /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + */ + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); + + // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = W; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; + } + else if constexpr (MajorMode == GMMA::Major::K) + { + /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + */ + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); + + // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = W; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.stride_byte_offset_ = stride_01; + desc.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); + } + +#if 0 + // DEBUG and SANITY + assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation + assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later + if (thread0()) { + print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); + print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); + //print(" desc canonical layout: "); print(canonical_layout); print("\n"); + print(desc); + } +#endif + + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct gmma_descriptor_iterator +{ + GmmaDescriptor desc_; + + // Dereference returns the GmmaDescriptor + CUTE_HOST_DEVICE constexpr + GmmaDescriptor const& operator*() const { return desc_; } + + // Advance and return a new GmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + GmmaDescriptor operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + gmma_descriptor_iterator operator+(Index const& offset) const + { + // offset is in the units of uint128_t (4LSB of start_address not included) + + //GmmaDescriptor desc = desc_; + //desc.start_address_ += uint16_t(offset); + //desc.reg32_[0] += uint16_t(offset); // Generates better asm than adding to the bitfield + + // May need to update base_offset if swizzle alignment isn't guaranteed + //desc.base_offset_ = 0; + //assert((desc.start_address_ & 0b111000) == 0); // Assert base_offset is 0, generalize later + + //return {desc}; + + // The above seems to not work for some reason... + return {desc_ + uint64_t(offset)}; + } +}; + +template +struct smem_desc : gmma_descriptor_iterator {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_gmma_desc_fragment(Tensor const& t) +{ + // Cast to a uint128_t tensor for GMMA Desc iteration + return make_tensor(gmma_descriptor_iterator{make_gmma_desc(tensor<0>(t))}, + recast(t).layout()); +} + +// Recast a tensor to a tensor of gmma_descriptor_iterator +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor, type_list>) +{ + return make_gmma_desc_fragment(tensor); +} + +// Recast a gmma_descriptor_iterator Tensor to uint64_t, it's RegType +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor,TLayout> const& tensor, type_list) +{ + static_assert(std::is_same::value, "Can only cast descriptors to uint64_t."); + return make_tensor(tensor.data(), Layout<_1,_0>{}); +} + +} // end namespace GMMA + +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (std::is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace GMMA { + +// Accumulator layouts +using CLayout_64x8 = Layout,Shape < _2,_2>>, + Stride,Stride<_64,_8>>>; + +using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, + Stride,Stride<_64,_8,_512>>>; + +// Register source layout for 32-bit value types +using ALayout_64x8 = Layout,Shape < _2, _2>>, + Stride,Stride< _8,_256>>>; + +// Register source layout for 16-bit value types +using ALayout_64x16 = CLayout_64x16; + +// Register source layout for 8-bit value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Shared memory source layouts for any value type +template +using ABLayout = Layout,Int>>, + Stride< _0,Stride< _1,Int>>>; + +} // namespace GMMA + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = bfloat16_t; + using ElementBVal = bfloat16_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = tfloat32_t; + using ElementBVal = tfloat32_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = int8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = int8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = int32_t; + using ElementAVal = uint8_t; + using ElementBVal = uint8_t; + using ElementCVal = int32_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/config.hpp b/include/cute/config.hpp new file mode 100644 index 0000000000..b2f4de8363 --- /dev/null +++ b/include/cute/config.hpp @@ -0,0 +1,121 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ +# define CUTE_DEVICE __forceinline__ __device__ +# define CUTE_HOST __forceinline__ __host__ +#else +# define CUTE_HOST_DEVICE inline +# define CUTE_DEVICE inline +# define CUTE_HOST inline +#endif // CUTE_HOST_DEVICE, CUTE_DEVICE + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_UNROLL #pragma unroll +# define CUTE_NO_UNROLL #pragma unroll 1 +#else +# define CUTE_UNROLL +# define CUTE_NO_UNROLL +#endif // CUTE_UNROLL + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_INLINE_CONSTANT static const __device__ +#else +# define CUTE_INLINE_CONSTANT static constexpr +#endif + +// Some versions of GCC < 11 have trouble deducing that a +// function with "auto" return type and all of its returns in an "if +// constexpr ... else" statement must actually return. Thus, GCC +// emits spurious "missing return statement" build warnings. +// Developers can suppress these warnings by using the +// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. +// It's harmless to use the macro for other GCC versions or other +// compilers, but it has no effect. +#if ! defined(CUTE_GCC_UNREACHABLE) +# if defined(__GNUC__) && __GNUC__ < 11 + // GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return + // statement" warnings without this little bit of help. +# define CUTE_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTE_GCC_UNREACHABLE +# endif +#endif + +// +// Assertion helpers +// + +#include + +#define CUTE_STATIC_ASSERT static_assert +#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) + +#if defined(__CUDA_ARCH__) +# define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory") +#else +# define CUTE_RUNTIME_ASSERT(x) assert(0 && x) +#endif + +// +// IO +// + +#include +#include +#include + +// +// Support +// + +#include + +// +// Basic types +// + +#include +#include +#include +#include +#include +#include +#include + +// +// Debugging utilities +// + +#include +#include diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp new file mode 100644 index 0000000000..49101fa7a9 --- /dev/null +++ b/include/cute/container/alignment.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// Test if a pointer is aligned to N bytes +template +CUTE_HOST_DEVICE constexpr +bool +is_byte_aligned(void const* const ptr) +{ + static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); + return (reinterpret_cast(ptr) & (N-1)) == 0; +} + +#if defined(__CUDACC__) +# define CUTE_ALIGNAS(n) __align__(n) +#else +# define CUTE_ALIGNAS(n) alignas(n) +#endif + +template +struct aligned_struct {}; + +template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; +template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; +template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; +template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; +template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; +template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; +template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; +template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; +template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; + +} // end namespace cute diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp new file mode 100644 index 0000000000..571ac0897c --- /dev/null +++ b/include/cute/container/array.hpp @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct array +{ + using value_type = T; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + { + using std::swap; + for (size_type i = 0; i < size(); ++i) { + swap((*this)[i], other[i]); + } + } + + value_type __elems_[N > 0 ? N : 1]; +}; + + +template +CUTE_HOST_DEVICE constexpr +bool operator==(array const& lhs, array const& rhs) +{ + for (std::size_t i = 0; i < N; ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +template +CUTE_HOST_DEVICE constexpr +void clear(array& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array& a, T const& value) +{ + a.fill(value); +} + +template +CUTE_HOST_DEVICE constexpr +void swap(array& a, array& b) +{ + a.swap(b); +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array +// + +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array&& a) +{ + static_assert(I < N, "Index out of range"); + return std::move(a[I]); +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end std diff --git a/include/cute/container/array_aligned.hpp b/include/cute/container/array_aligned.hpp new file mode 100644 index 0000000000..b1b357278d --- /dev/null +++ b/include/cute/container/array_aligned.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +namespace cute +{ + +template +struct array_aligned + : public aligned_struct +{ + /// Make sure the Alignment makes sense wrt the size of elements. + static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small"); + /// Alignment must be a power of two + static_assert(has_single_bit(Alignment), "Alignment must be a power of two"); + + using value_type = T; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(T const& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + // Not private, we want trivial type + //private: + + /// Storage type to use for Elements + using StorageType = typename uint_byte(Alignment)>::type; + + /// Ensure that there's enough storage for all elements + static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment"); + + /// Number of elements in the storage + static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType); + + /// The storage. + StorageType storage[storageN > 0 ? storageN : 1]; +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_aligned& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_aligned& a, T const& value) +{ + a.fill(value); +} + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::array +// + +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_aligned& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_aligned const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_aligned&& a) +{ + static_assert(I < N, "Index out of range"); + return std::move(a[I]); +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end std diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp new file mode 100644 index 0000000000..a217a671f7 --- /dev/null +++ b/include/cute/container/array_subbyte.hpp @@ -0,0 +1,613 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates subbyte trivial types + in a packed storage. +*/ + +#pragma once + +#include + +#include // sizeof_bits + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template +class array_subbyte +{ + public: + + /// Number of total bits in the array + static constexpr int kSizeBits = sizeof_bits::value * N; + + /// Storage type + using Storage = typename std::conditional< + (kSizeBits % 32) == 0, + uint32_t, + typename std::conditional< + (kSizeBits % 16) == 0, + uint16_t, + uint8_t + >::type + >::type; + + + /// Number of logical elements per stored object + static constexpr int kElementsPerStoredItem = sizeof_bits::value / sizeof_bits::value; + + /// Number of storage elements + static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + + /// Bitmask for covering one item + static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits::value) - 1); + + // + // C++ standard members with reference and iterator types omitted + // + + using value_type = T; + using pointer = value_type*; + using const_pointer = value_type const*; + + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + // + // References + // + + /// Reference object inserts or extracts sub-byte items + class reference { + /// Pointer to storage element + Storage* ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + /// Default ctor + CUTE_HOST_DEVICE constexpr + reference() : ptr_(nullptr), idx_(0) {} + + /// Ctor + CUTE_HOST_DEVICE constexpr + reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {} + + /// Assignment + CUTE_HOST_DEVICE constexpr + reference& operator=(T x) { + Storage item = (reinterpret_cast(x) & bit_mask_); + Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits::value))); + *ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits::value))); + return *this; + } + + CUTE_HOST_DEVICE constexpr + T get() const { + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); + return reinterpret_cast(item); + } + + /// Extract to type T -- disable if T == bool + template ::value)> + CUTE_HOST_DEVICE constexpr + operator T() const { + return get(); + } + + // Extract to bool -- potentially faster impl + CUTE_HOST_DEVICE constexpr + operator bool() const { + return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); + } + + /// Explicit cast to int + CUTE_HOST_DEVICE constexpr + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to float + CUTE_HOST_DEVICE constexpr + explicit operator float() const { + return float(get()); + } + }; + + /// Reference object extracts sub-byte items + class const_reference { + + /// Pointer to storage element + Storage const* ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + /// Default ctor + CUTE_HOST_DEVICE constexpr + const_reference(): ptr_(nullptr), idx_(0) { } + + /// Ctor + CUTE_HOST_DEVICE constexpr + const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + const T get() const { + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); + return reinterpret_cast(item); + } + + /// Extract to type T -- disable if T == bool + template ::value)> + CUTE_HOST_DEVICE constexpr + operator T() const { + return get(); + } + + // Extract to bool -- potentially faster impl + CUTE_HOST_DEVICE constexpr + operator bool() const { + return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); + } + + /// Explicit cast to int + CUTE_HOST_DEVICE constexpr + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to float + CUTE_HOST_DEVICE constexpr + explicit operator float() const { + return float(get()); + } + }; + + // + // Iterators + // + + /// Bidirectional iterator over elements + class iterator { + + /// Pointer to storage element + Storage* ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTE_HOST_DEVICE constexpr + iterator(): ptr_(nullptr), idx_(0) { } + + CUTE_HOST_DEVICE constexpr + iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + iterator& operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + iterator& operator--() { + if (idx_) { + --idx_; + } else { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + iterator operator++(int) { + iterator ret(*this); + ++(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + iterator operator--(int) { + iterator ret(*this); + --(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + iterator& operator+=(int k) { + idx_ += k; + ptr_ += idx_ / kElementsPerStoredItem; + idx_ = idx_ % kElementsPerStoredItem; + return *this; + } + + CUTE_HOST_DEVICE constexpr + iterator operator+(int k) const { + return iterator(ptr_,idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + reference operator[](int k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + bool operator==(iterator const& other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTE_HOST_DEVICE constexpr + bool operator!=(iterator const& other) const { + return !(*this == other); + } + }; + + /// Bidirectional constant iterator over elements + class const_iterator { + + /// Pointer to storage element + Storage const* ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTE_HOST_DEVICE constexpr + const_iterator(): ptr_(nullptr), idx_(0) { } + + CUTE_HOST_DEVICE constexpr + const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + const_iterator& operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + const_iterator& operator--() { + if (idx_) { + --idx_; + } else { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + const_iterator operator++(int) { + iterator ret(*this); + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return ret; + } + + CUTE_HOST_DEVICE constexpr + const_iterator operator--(int) { + iterator ret(*this); + if (idx_) { + --idx_; + } else { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + return ret; + } + + CUTE_HOST_DEVICE constexpr + const_iterator& operator+=(int k) { + idx_ += k; + ptr_ += idx_ / kElementsPerStoredItem; + idx_ = idx_ % kElementsPerStoredItem; + return *this; + } + + CUTE_HOST_DEVICE constexpr + const_iterator operator+(int k) const { + return const_iterator(ptr_,idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator*() const { + return const_reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](int k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + bool operator==(iterator const& other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTE_HOST_DEVICE constexpr + bool operator!=(iterator const& other) const { + return !(*this == other); + } + }; + +private: + + /// Internal storage + Storage storage[kStorageElements]; + +public: + + CUTE_HOST_DEVICE constexpr + array_subbyte() { } + + CUTE_HOST_DEVICE constexpr + array_subbyte(array_subbyte const& x) { + CUTE_UNROLL + for (unsigned i = 0; i < kStorageElements; ++i) { + storage[i] = x.storage[i]; + } + } + + CUTE_HOST_DEVICE constexpr + size_type size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const { + return !N; + } + + /// Efficient clear method + CUTE_HOST_DEVICE constexpr + void clear() { + CUTE_UNROLL + for (unsigned i = 0; i < kStorageElements; ++i) { + storage[i] = Storage(0); + } + } + + // Efficient fill method + CUTE_HOST_DEVICE constexpr + void fill(T const& value) { + Storage item = (reinterpret_cast(value) & bit_mask_); + + // Reproduce the value over the bits of the storage item + CUTE_UNROLL + for (unsigned s = sizeof_bits::value; s < sizeof_bits::value; s *= 2) { + item |= item << s; + } + + CUTE_UNROLL + for (unsigned i = 0; i < kStorageElements; ++i) { + storage[i] = item; + } + } + + CUTE_HOST_DEVICE constexpr + reference at(size_type pos) { + return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTE_HOST_DEVICE constexpr + const_reference at(size_type pos) const { + return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + reference front() { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + reference back() { + return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const { + return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTE_HOST_DEVICE constexpr + pointer data() { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + Storage* raw_data() { + return storage; + } + + CUTE_HOST_DEVICE constexpr + Storage const* raw_data() const { + return storage; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() { + return iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const { + return const_iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() { + return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const { + return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const { + return end(); + } + + // + // Comparison operators + // + +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_subbyte& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_subbyte& a, T const& value) +{ + a.fill(value); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +// +// Specialize tuple-related functionality for cute::array_subbyte +// + +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_subbyte& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_subbyte const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_subbyte&& a) +{ + static_assert(I < N, "Index out of range"); + return std::move(a[I]); +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std diff --git a/include/cute/container/array_view.hpp b/include/cute/container/array_view.hpp new file mode 100644 index 0000000000..51b3ccc07d --- /dev/null +++ b/include/cute/container/array_view.hpp @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct array_view +{ + using value_type = T; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + array_view(array& a) + : __elems_(a.data()) {} + + CUTE_HOST_DEVICE + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE + const T* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE + void fill(const T& value) + { + for(auto& e : *this) + { + e = value; + } + } + + CUTE_HOST_DEVICE + void swap(array_view& other) + { + using std::swap; + swap(__elems_, other.__elems_); + } + + value_type* __elems_; +}; + + +template +CUTE_HOST_DEVICE +bool operator==(const array_view& lhs, const array_view& rhs) +{ + for(std::size_t i = 0; i < N; ++i) + { + if(lhs[i] != rhs[i]) return false; + } + + return true; +} + +template +CUTE_HOST_DEVICE +void clear(array_view& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE +void swap(array_view& a, array_view& b) +{ + a.swap(b); +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array_view +// + +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& +get(array_view& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +const T& +get(const array_view& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& +get(array_view&& a) +{ + static_assert(I < N, "Index out of range"); + return std::move(a[I]); +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end std diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp new file mode 100644 index 0000000000..06b08754c9 --- /dev/null +++ b/include/cute/container/bit_field.hpp @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Portable bit field that supports byte and word straddling that can + be used in unions to bit-wise define parameters. +*/ + +#pragma once + +#include + +#include // uint_bit_t + +namespace cute +{ + +class dummy_type {}; + +template +struct bit_field +{ + static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); + + // value_type: Use the smallest value type that fits NumBits + static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : + (NumBits <= 16) ? 16 : + (NumBits <= 32) ? 32 : 64; + using value_type = cute::uint_bit_t; + // storage_type: Use the smallest storage_type that avoids boundary crossing + static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : + (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : + (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; + using storage_type = cute::uint_bit_t; + + static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same::value, + "sizeof(OtherValueType) must be same as sizeof(value_type)."); + + // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) + static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; + // Index of storage value for BitStart + static constexpr uint32_t idx = BitStart / storage_type_bits; + // Bit of data_[idx] for BitStart + static constexpr uint32_t bit_lo = BitStart % storage_type_bits; + // Number of bits in data_[idx] used for NumBits if straddling, else 0 + static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; + + // NumBits mask + static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1); + // NumBits mask for BitStart + static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; + // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 + static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0; + + storage_type data_[N]; + + // Get value + CUTE_HOST_DEVICE constexpr + value_type get() const { + storage_type result = (data_[idx] & mask_lo) >> bit_lo; + if constexpr (bit_hi) { + result |= (data_[idx+1] & mask_hi) << bit_hi; + } + return static_cast(result); + } + + // Set value + CUTE_HOST_DEVICE constexpr + void set(value_type x) { + storage_type item = static_cast(x & mask); + data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); + if constexpr (bit_hi) { + data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); + } + } + + // Assign value + CUTE_HOST_DEVICE constexpr + bit_field& operator=(value_type x) { + set(x); + return *this; + } + + // Cast to value + CUTE_HOST_DEVICE constexpr + operator value_type () const { + return get(); + } + + // Assign OtherValueType + CUTE_HOST_DEVICE constexpr + bit_field& operator=(OtherValueType x) { + return *this = *reinterpret_cast(&x); + } + + // Cast to OtherValueType + CUTE_HOST_DEVICE constexpr + operator OtherValueType () const { + value_type x = get(); + return *reinterpret_cast(&x); + } +}; + +} // end namespace cute diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp new file mode 100644 index 0000000000..1b3ffa42d4 --- /dev/null +++ b/include/cute/container/tuple.hpp @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include +#include + +#include // cute::true_type, cute::false_type +//#include // Advanced optimizations + +#if 0 +// +// Use of agency::tuple is functional, but is over-engineered for our purposes... +// This tends to result in slow compilation times and unintentionally propagated cvref types +// + +#include + +namespace cute +{ + +using agency::tuple; + +using agency::make_tuple; +using agency::tuple_cat; + +} // end namespace cute +#endif + +// cute::tuple is like std::tuple, with two differences. +// +// 1. It works on both host and device. +// 2. Its template arguments must be semiregular types. +// +// Semiregular types are default constructible and copyable. +// They include "value types" like int or float, +// but do _not_ include references like int& or float&. +// (See std::tie for an example of a tuple of references.) +// +// This is simplified over the implementation in std:: and agency:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding +// construction calls, and ignoring any need for unique element addresses. +// +// Over the agency::tuple implementation, this appears to accelerate compilation times by over 3x. + +namespace cute +{ + +namespace detail +{ + +// EBO stands for "empty base optimization." +// We use this technique to ensure that cute::tuple +// doesn't need to waste space storing any template arguments +// of cute::tuple that have no data (like integral_constant). +// Otherwise, cute::tuple would need to spend at least 1 byte +// for each of its template arguments. +// +// EBO always "holds" a single value of type T. +// N is like an array index that TupleBase uses +// to access the desired tuple element. +template ::value> +struct EBO; + +// Specialization for types T that have no data; +// the "static tuple leaf." Valid T here include +// integral_constant, Int, +// and any other semiregular type +// for which std::is_empty_v is true. +template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() {} + + CUTE_HOST_DEVICE constexpr + EBO(T const&) {} +}; + +template +CUTE_HOST_DEVICE constexpr T getv(EBO const&) +{ return {}; } + +// Specialization for types T that are not empty; +// the "dynamic tuple leaf." Valid T here include int, +// any other integral or floating-point type, +// or any semiregular type for which std::is_empty_v is false. +template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() : t_{} {} + + template + CUTE_HOST_DEVICE constexpr + EBO(U const& u) : t_{u} {} + + T t_; +}; + +template +CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T& getv(EBO& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) +{ return static_cast(x.t_); } + +template +struct TupleBase; + +// Base class of cute::tuple. +// It inherits from EBO for each (i, t) in (I..., T...). +// The actual storage (for nonempty t) lives in the base classes. +// index_sequence is a way to wrap up a sequence of zero or more +// compile-time integer values in a single type. +// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice, +// as the type alias TupleBase below indicates. +template +struct TupleBase, T...> + : EBO... +{ + CUTE_HOST_DEVICE constexpr + TupleBase() {} + + template + CUTE_HOST_DEVICE constexpr explicit + TupleBase(U const&... u) + : EBO(u)... {} + + template + CUTE_HOST_DEVICE constexpr + TupleBase(TupleBase, U...> const& u) + : EBO(getv(static_cast const&>(u)))... {} +}; + +} // end namespace detail + +// make_index_sequence returns index_sequence<0, 1, ..., K-1>. +template +using TupleBase = detail::TupleBase, T...>; + +// This is the actual cute::tuple class. +// The storage (if any) lives in TupleBase's EBO base classes. +template +struct tuple : TupleBase +{ + CUTE_HOST_DEVICE constexpr + tuple() {} + + template + CUTE_HOST_DEVICE constexpr + tuple(U const&... u) : TupleBase(u...) {} + + template + CUTE_HOST_DEVICE constexpr + tuple(tuple const& u) + : TupleBase(static_cast const&>(u)) {} +}; + +// +// get for cute::tuple (just like std::get for std::tuple) +// + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple const& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple&& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast&&>(t)); +} + +// +// Custom is_tuple trait simply checks the existence of std::tuple_size +// and assumes std::get(.), std::tuple_element +// +namespace detail { + +template +std::integral_constant::value >= 0> has_tuple_size(int); + +template +std::false_type has_tuple_size(...); + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size(0)) {}; + +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + +// +// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, +// just like std::tuple_cat for std::tuple. +// + +#if 0 +// Original implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); +} +#endif + +#if 1 +// Extended implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, + std::index_sequence, std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, + std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, + std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) +{ + return detail::tuple_cat(t0, t1, t2, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) +{ + return detail::tuple_cat(t0, t1, t2, t3, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) +{ + return detail::tuple_cat(t0, t1, t2, t3, t4, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...); +} +#endif + +#if 0 +// Outer-Inner indexing trick to concat all tuples at once + +namespace detail { + +template +struct tuple_cat_helper +{ + static constexpr cute::array ns = {Ns...}; + + static constexpr std::size_t total_size() { + std::size_t sum = 0; + for (std::size_t n : ns) sum += n; + return sum; + } + static constexpr std::size_t total_size_ = total_size(); + + static constexpr auto values() { + cute::array outer_inner = {}; + + std::size_t idx = 0; + for (std::size_t i = 0; i < ns.size(); ++i) { + for (std::size_t j = 0; j < ns[i]; ++j, ++idx) { + outer_inner[idx][0] = i; + outer_inner[idx][1] = j; + } + } + return outer_inner; + } + static constexpr auto outer_inner_ = values(); + + using total_sequence = std::make_index_sequence; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuple const& t, std::index_sequence) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + std::index_sequence, std::index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + std::make_index_sequence::value>{}, + std::make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuples const&... ts) +{ + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(make_tuple(ts...), typename Helper::total_sequence{}); +} +#endif + +// +// Equality operators +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +equal_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == std::tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == std::tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return (get(a) == get(b)) && equal_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return detail::equal_impl<0>(t, u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return cute::false_type{}; +} + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return !(t == u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return cute::true_type{}; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] +// -- colexicographical comparison [reverse, reflected, revref] +// -- element-wise comparison [any,all] +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// That said, see int_tuple for more explicitly named common comparison ops. +// + +// +// Shortcuts +// + +//using std::get; +using std::tuple_size; +using std::tuple_element; +using std::tuple_element_t; + +// +// Display utilities +// + +namespace detail { + +template +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, + std::index_sequence, char s = '(', char e = ')') +{ + using eat = int[]; + using cute::print; + (void) eat {(print(s), 0), + (print(Is == 0 ? "" : ","), print(get(t)), 0)..., + (print(e), 0)}; +} + +template +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, + std::index_sequence, char s = '(', char e = ')') +{ + using eat = int[]; + (void) eat {(void(os << s), 0), + (void(os << (Is == 0 ? "" : ",") << get(t)), 0)..., + (void(os << e), 0)}; + return os; +} + +} // end namespace detail + +template ::value)> +CUTE_HOST_DEVICE void print(Tuple const& t) +{ + return detail::print_tuple(t, std::make_index_sequence::value>{}); +} + +template ::value)> +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) +{ + return detail::print_tuple_os(os, t, std::make_index_sequence::value>{}); +} + +} // end namespace cute + +// +// std:: compatability +// + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> + : std::tuple_element> +{}; + +} // end std diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp new file mode 100644 index 0000000000..c082a6daaf --- /dev/null +++ b/include/cute/container/type_list.hpp @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +template +struct type_c { + using type = T; +}; + +template +struct type_list {}; + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::type_list +// + +#include +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +std::tuple_element_t> +get(type_list&) noexcept { + return {}; +} +template +CUTE_HOST_DEVICE constexpr +std::tuple_element_t> +get(type_list const& t) noexcept { + return {}; +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +} // end namespace std diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp new file mode 100644 index 0000000000..045e7210b1 --- /dev/null +++ b/include/cute/int_tuple.hpp @@ -0,0 +1,827 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +template +using IntTuple = cute::tuple; + +// Construct an IntTuple with all value-elements +template +CUTE_HOST_DEVICE constexpr +IntTuple +make_int_tuple(Ts const&... t) +{ + return {t...}; +} + +/** if rank(int) == 1, then get<0>(int) should work too + */ +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +/** Custom recursive get for anything that implements get(.) + */ +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(Tuple&& t) noexcept +{ + return get(get(static_cast(t))); +} + +// +// rank +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using rank_t = decltype(rank(std::declval())); + +template +static constexpr int rank_v = rank_t::value; + +// +// shape +// + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return transform(s, [](auto const& a) { return shape(a); }); + } else { + return s; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return shape(get(s)); + } else { + return get(shape(s)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max +// + +template +CUTE_HOST_DEVICE constexpr +auto +max(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::max(t0, cute::max(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// min +// + +template +CUTE_HOST_DEVICE constexpr +auto +min(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::min(t0, cute::min(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// depth +// + +template +CUTE_HOST_DEVICE constexpr +auto +depth(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using depth_t = decltype(depth(std::declval())); + +template +static constexpr int depth_v = depth_t::value; + +// +// product +// + +template +CUTE_HOST_DEVICE constexpr +auto +product(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// Product of a subrange +template +CUTE_HOST_DEVICE constexpr +auto +product(Tuple const& a) +{ + return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }, make_range{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +product_each(Tuple const& t) +{ + return transform(t, [](auto const& x) { return product(x); }); +} + +// Return the product of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(IntTuple const& a) +{ + if constexpr (sizeof...(Is) == 0) { + return product(a); + } else { + return product(get(a)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +static constexpr int size_v = decltype(size(std::declval()))::value; + +// +// sum +// + +template +CUTE_HOST_DEVICE constexpr +auto +sum(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// inner_product +// + +template +CUTE_HOST_DEVICE constexpr +auto +inner_product(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, + [](auto const&... v) { return (Int<0>{} + ... + v); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// ceil_div +// + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implictly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, make_tuple(make_tuple(), b), + [] (auto const& init, auto const& ai) { + return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + }); + return result; + } + } else { + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else { // int int + //assert(a % b == 0 || b % a == 0); + return a / b != 0 ? a / b : signum(a) * signum(b); // divide with rounding away from zero + } + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes that are static constants + * @pre t % u == 0 || u % t == 0 + * @result if t % u == 0, then t / u + * if u % t == 0, then signum(t) * signum(u) + */ +template +CUTE_HOST_DEVICE constexpr +constant +shape_div(constant const&, constant const&) +{ + static_assert(t % u == 0 || u % t == 0, "Static shape_div failure"); + return {}; +} + +/** Return a tuple the same profile as A scaled by corresponding elements in B + */ +template +CUTE_HOST_DEVICE constexpr +auto +elem_scale(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); + } else { + return a * product(b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Test if two IntTuple have the same profile (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +congruent(IntTupleA const& a, IntTupleB const& b) +{ + return bool_constant::value>{}; +} + +template +using is_congruent = decltype(congruent(std::declval(), std::declval())); + +/** Test if Shape B is compatible with Shape A: + * Any coordinate into A can also be used as a coordinate into B + * A <= B is a partially ordered set of factored shapes + */ +template +CUTE_HOST_DEVICE constexpr +auto +compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a == size(b); + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_compatible = decltype(compatible(std::declval(), std::declval())); + +/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> + */ +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); + } else if constexpr (is_constant<0, IntTupleA>::value) { + return Int<1>{}; + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tuple const& t) +{ + return filter_zeros(t, t); +} + +// +// Converters and constructors with arrays and params +// + +/** Make an IntTuple of rank N from an Indexable array. + * Access elements up to a dynamic index n, then use init (requires compatible types) + * Consider cute::take if all indexing is known to be valid + * \code + * std::vector a = {6,3,4}; + * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_int_tuple(Indexable const& t, int n, T const& init) +{ + static_assert(N > 0); + if constexpr (N == 1) { + return 0 < n ? t[0] : init; + } else { + return transform(make_seq{}, [&](auto i) { return i < n ? t[i] : init; }); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Fill the dynamic values of a Tuple with values from another Tuple + * \code + * auto params = make_int_tuple(6,3,4); + * cute::tuple, cute::tuple>, int, Int<2>> result; + * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +fill_int_tuple_from(Tuple& result, TupleV const& vals) +{ + return fold(result, vals, [](auto const& init, auto&& r) { + if constexpr (is_static>::value) { // Skip static elements of result + return init; + } else if constexpr (is_tuple>::value) { // Recurse into tuples + return fill_int_tuple_from(r, init); + } else { // Assign and consume arg + static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); + r = get<0>(init); + return remove<0>(init); + } + + CUTE_GCC_UNREACHABLE; + }); +} + +/** Make a "Tuple" by filling in the dynamic values in order from the arguments + * \code + * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; + * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +Tuple +make_int_tuple_from(Ts const&... ts) +{ + Tuple result = Tuple{}; + fill_int_tuple_from(result, make_tuple(ts...)); + return result; +} + +/** Convert a tuple to a flat homogeneous array of type T + * \code + * auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * cute::array result = to_array(tup); // [1,6,3,3,4,2] + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +to_array(IntTuple const& t) +{ + auto flat_t = flatten_to_tuple(t); + constexpr int N = tuple_size::value; + cute::array result; + for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); + return result; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout +// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout +// -- element-wise comparison [any,all] : +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// When actually desiring to order coordinates, the user should map them to +// their indices within the Layout they came from: +// e.g. layoutX(coordA) < layoutX(coordB) +// That said, we implement the three most common ways to compare tuples below. +// These are implemented with slighly more explicit names than op<. +// + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + constexpr std::size_t A = tuple_size::value - 1 - I; + constexpr std::size_t B = tuple_size::value - 1 - I; + return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return elem_less(get(a), get(b)) && elem_less_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Lexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::lex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_leq(T const& t, U const& u) { + return !lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_gtr(T const& t, U const& u) { + return lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_geq(T const& t, U const& u) { + return !lex_less(t, u); +} + +// Colexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::colex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_leq(T const& t, U const& u) { + return !colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_gtr(T const& t, U const& u) { + return colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_geq(T const& t, U const& u) { + return !colex_less(t, u); +} + +// Elementwise [all] comparison + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::elem_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_leq(T const& t, U const& u) { + return !elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_gtr(T const& t, U const& u) { + return elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_geq(T const& t, U const& u) { + return !elem_less(t, u); +} + +/** Increment a (dynamic) coord lexicographically within a shape + * \code + * auto shape = make_shape(1,2,make_shape(2,3),3); + * + * int i = 0; + * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { + * std::cout << i++ << ": " << coord << std::endl; + * } + * assert(i == size(shape)); + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape, seq) +{ + cute::increment(get(coord), get(shape)); + if constexpr (sizeof...(Is) != 0) { + if (back(get(coord)) == back(get(shape))) { + back(get(coord)) = 0; + increment(coord, shape, seq{}); + } + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape) +{ + if constexpr (is_integral::value && is_integral::value) { + ++coord; + } else if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + detail::increment(coord, shape, tuple_seq{}); + } else { + static_assert(sizeof(Coord) == 0, "Invalid parameters"); + } +} + +struct ForwardCoordIteratorSentinal +{}; + +// A forward iterator for a coordinate that starts from zero and goes to shape +template +struct ForwardCoordIterator +{ + static_assert(is_congruent::value); + + CUTE_HOST_DEVICE constexpr + Coord const& operator*() const { return coord; } + + CUTE_HOST_DEVICE constexpr + ForwardCoordIterator& operator++() { increment(coord, shape); return *this; } + + // Sentinal for the end of the implied range + CUTE_HOST_DEVICE constexpr + bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } + // NOTE: These are expensive, avoid use + CUTE_HOST_DEVICE constexpr + bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } + + Coord coord; + Shape const& shape; +}; + +// A forward iterator for a coordinate that starts from zero +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + auto coord = repeat_like(shape, int(0)); + return ForwardCoordIterator{coord,shape}; +} + +} // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp new file mode 100644 index 0000000000..fe937ee738 --- /dev/null +++ b/include/cute/layout.hpp @@ -0,0 +1,1638 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// Aliases + +template +using Shape = IntTuple; + +template +using Stride = IntTuple; + +template +using Step = IntTuple; + +template +using Coord = IntTuple; + +template +CUTE_HOST_DEVICE constexpr +Shape +make_shape(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Stride +make_stride(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Step +make_step(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Coord +make_coord(Ts const&... t) { + return {t...}; +} + + +template > +struct Layout + : private cute::tuple // EBO for static layouts +{ + // Avoid bad CTAD: + // Layout smem = GMMA::Layout_MN_SW128_Atom; + // Should fail because smem is a ComposedLayout (SwizzleLayout) and not a Layout + static_assert(is_integral::value || is_tuple::value); + + // Expensive in compilation time... + //static_assert(is_congruent::value, + // "Shape and Stride must have the same hierarchical structure"); + //static_assert(is_integral::value || is_tuple::value); + + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CUTE_HOST_DEVICE constexpr + Layout(LogicalShape const& logical_shape = {}, + LogicalStride const& logical_stride = {}) + : cute::tuple(logical_shape, logical_stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v ; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() { + return get<0,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return get<0,I...>(static_cast const&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() { + return get<1,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return get<1,I...>(static_cast const&>(*this)); + } + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return crd2idx(coord, shape(), stride()); + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // Map a linear index to a hier ND logical coordinate + // NOTE: Dangerous and error-prone + template + CUTE_HOST_DEVICE constexpr + auto + operator[](Int const& linear_idx) const { + static_assert(is_integral::value); + return get_hier_coord(linear_idx); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE: Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return cute::idx2crd(idx, shape(), stride()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return cute::crd2idx(this->get_hier_coord(idx), shape()); + } + + // + // Coordinate to Coordinate + // + +#if 0 + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post congruent(@a result, shape()) + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_hier_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), shape()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_flat_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), product_each(shape())); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post is_integral::value + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_1d_coord(Coord const& crd) const { + //return cute::crd2crd(crd, shape(), product(shape())); + return cute::crd2idx(crd, shape()); + } +#endif +}; + + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + + +template ::value || is_integral::value) && + (is_tuple::value || is_integral::value))> +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, Stride const& stride) +{ + return Layout(shape, stride); +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape) +{ + return make_layout(shape, compact_col_major(shape)); +} + +// Construct a layout from multiple layouts by +// concatenating each layout as an independent mode +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const&... layouts) +{ + return make_layout(make_shape (layouts.shape()...), + make_stride(layouts.stride()...)); +} + +// +// Convenience tags for common layouts +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, GenColMajor) +{ + return make_layout(shape, compact_col_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, GenRowMajor) +{ + return make_layout(shape, compact_row_major(shape)); +} + +// Follow the same ordering induced by the strides, but make the layout compact +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Shape const& shape, Order const& order) +{ + static_assert(is_static::value && is_static::value); + return make_layout(shape, compact_order(shape, order)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Layout const& layout) +{ + return make_ordered_layout(layout.shape(), layout.stride()); +} + +// Make a layout of the same shape that is either ordered or colmajor depending on staticness +template +CUTE_HOST_DEVICE constexpr +auto +make_layout_like(Layout const& layout) +{ + if constexpr (is_static::value && is_static::value) { + return make_ordered_layout(layout.shape(), layout.stride()); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +// Make a layout of the same shape, +// with mode-0 being colmajor then following the the mode order in layout +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + auto shape = replace<0>(layout.shape(), size<0>(layout)); + auto order = replace<0>(layout.stride(), Int<0>{}); + if constexpr (is_static::value && is_static::value) { + return make_ordered_layout(shape, order); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_layout(Shape const& shape) +{ + return make_layout(shape, make_basis_like(shape)); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(Layout const& layout) +{ + // Let the static_asserts in get(shape|stride) catch problems + return make_layout(get(layout.shape()), get(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(Layout const& layout) +{ + // Let the static_asserts in take(shape|stride) catch problems + return make_layout(take(layout.shape()), take(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Layout const& layout) +{ + return make_layout(flatten(layout.shape()), flatten(layout.stride())); +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Layout const& layout) +{ + if constexpr (sizeof...(Is) == 0) { + return layout; + } else { + return get(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout& layout) +{ + return layout.template shape(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout const& layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout& layout) +{ + return layout.template stride(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout const& layout) +{ + return layout.template stride(); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Layout const& layout) +{ + return size(shape(layout)); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(Layout const& layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(Layout const& layout) +{ + return depth(shape(layout)); +} + +// Return the codomain size of a mode +// @return M smallest integer such that @a sub_layout(c) < M for all c < size(@a sub_layout) +// where sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + // Protect against negative strides + auto abs_sub_layout = make_layout(shape(layout), + transform_leaf(stride(layout), abs_fn{})); + return abs_sub_layout(size(abs_sub_layout) - Int<1>{}) + Int<1>{}; +} + +template +using cosize_t = decltype(cosize(std::declval())); + +template +static constexpr int cosize_v = cosize_t::value; + +// Equality +// Return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + +// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& c, Layout const& layout) +{ + return crd2idx(c, layout.shape(), layout.stride()); +} + +// +// Slice and Dice a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& c, Layout const& layout) +{ + return make_layout(slice(c, layout.shape()), + slice(c, layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, Layout const& layout) +{ + return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +dice(Coord const& c, Layout const& layout) +{ + return make_layout(dice(c, layout.shape()), + dice(c, layout.stride())); +} + +// +// Transform the modes of a layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f, seq) +{ + return make_layout(f(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) +{ + return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f) +{ + return detail::transform_layout(t, f, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) +{ + constexpr int R0 = decltype(rank(t0))::value; + constexpr int R1 = decltype(rank(t1))::value; + constexpr int R = (R0 < R1) ? R0 : R1; + return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); +} + +// +// Coalesce and Filter +// + +namespace detail { + +// Look at each element and the front of the stack (in order of priority) +// front(NewLayout) get(Layout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_front s1:d1 +// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 +// s0:d0 s1:d1 => prepend s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == -1) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return bw_coalesce(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + // Merge modes because the shapes and strides match + return bw_coalesce(old_shape, old_stride, + replace_front(new_shape, get(old_shape) * get<0>(new_shape)), + replace_front(new_stride, get(old_stride))); + } else { + // Can't replace or merge, so prepend a new mode + return bw_coalesce(old_shape, old_stride, + prepend(new_shape, get(old_shape)), + prepend(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); +} + +// Apply coalesce at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); + } else { + return coalesce(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout) +{ + return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); +} + +// Remove all of the 0-strides and 1-sizes +// Return 1-shape if empty +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout) +{ + return coalesce(filter_zeros(layout)); +} + +// Apply filter at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); + } else { + return filter(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Append, Prepend, Replace +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +replace(Layout const& layout, + Layout const& x) +{ + return make_layout(replace(layout.shape(), x.shape()), + replace(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(Layout const& layout) +{ + return make_layout(group(layout.shape()), + group(layout.stride())); +} + +// +// Composition of two layouts: lhs o rhs +// @post compatible(rhs, result) +// @post result(c) = lhs(rhs(c)) +// for all c in the domain of result +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + RShape const& rhs_shape, RStride const& rhs_stride) +{ + if constexpr (is_tuple::value) { + // Apply the right-distributivity of Layout composition + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition(lhs, s, d); }); + } else + if constexpr (is_scaled_basis::value) { + // Special case for a ScaledBasis stride + return composition(get(lhs), rhs_shape, rhs_stride.value()); + } else + if constexpr (is_integral::value) { + // Integral Rstride (and RShape) + + // NOTE: Should only flatten once for efficiency + auto flat_shape = flatten(lhs.shape()); + auto flat_stride = flatten(lhs.stride()); + [[maybe_unused]] constexpr int R = rank(flat_shape); + + if constexpr (is_constant<0, RStride>::value) { + // Special case shortcut for any static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { + // Special case shortcut for any integral LShape + auto result_stride = rhs_stride * flat_stride; + return Layout{rhs_shape, result_stride}; + } else + if constexpr (is_constant<1, RStride>::value) { + // Special case shortcut for any static stride-1 + auto result_shape_0 = take<0,R-1>(flat_shape); + + // Mod out the rhs_shape from the lhs.shape() + auto const [result_shape_1, rest_shape] = fold(result_shape_0, make_tuple(make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, get(lhs.stride()) + return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); + } else + { + // General case + auto result_shape_0 = take<0,R-1>(flat_shape); + auto result_stride_0 = take<0,R-1>(flat_stride); + + // Divide out the rhs_stride from the lhs.shape() + auto const [result_shape_1, rest_stride] = fold(result_shape_0, make_tuple(make_tuple(), rhs_stride), + [] (auto const& init, auto const& di) { + return make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + }); + + // Apply any lhs.shape() changes to the stride + auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); + + // Mod out the rhs_shape from the lhs.shape() + auto const [result_shape_2, rest_shape] = fold(result_shape_1, make_tuple(make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) + return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Layout const& rhs) +{ + //return detail::composition(flatten(lhs), rhs.shape(), rhs.stride()); + return detail::composition(lhs, rhs.shape(), rhs.stride()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + IntTuple const& rhs) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + // Drop any modes of lhs that aren't hit by rhs + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { + return lhs; + } else { + return composition(lhs, make_layout(rhs)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Complement +// +// Build the complement of a layout. +// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); +// @post For all i in [1,size(@a result)), +// @a result(i) < @a result(i-1) +// For all j in [0, size(@a layout)), +// @a result(i) != @a layout(j) +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoSizeHi const& cosize_hi) +{ + // Remove the stride-0 modes, the size-1 modes, and flatten the layout + auto flat_layout = filter(layout); + + if constexpr (is_constant<0, decltype(flat_layout.stride())>::value) { + // Special case for stride-0 layout + return make_layout(cosize_hi); + } else { + // General case + constexpr int R = decltype(rank(flat_layout))::value; + static_assert(R == 1 || is_static::value, + "Dynamic-stride complement only for rank-1 layouts"); + + // Should just be a sort and a fold... + // Then we could even handle dynamic strides (but they would destroy all static strides) + auto result = fold(make_seq{}, + make_tuple(flat_layout.shape(), + flat_layout.stride(), + make_tuple(), + make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto curr_stride = cute::min(get<1>(init)); + auto curr_idx = find(get<1>(init), curr_stride); + auto curr_shape = get(get<0>(init)); + + return make_tuple(remove(get<0>(init)), // Remove the curr shape + remove(get<1>(init)), // Remove the curr stride + append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride + append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride + }); + + // Append the last shape mode + auto result_stride = get<3>(result); + auto result_shape = append(get<2>(result), get<1,0>(result) / back(result_stride)); // new shape = curr_stride / last_stride + + // Compute the rest_stride + auto rest_stride = get<0,0>(result) * get<1,0>(result); + //return make_layout(append(result_shape, ceil_div(cosize_hi, rest_stride)), append(result_stride, rest_stride)); + // Jump into coalesce and append (ceil_div(cosize_hi, rest_stride), rest_stride) + return detail::bw_coalesce(result_shape, result_stride, ceil_div(cosize_hi, rest_stride), rest_stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// Right-Inverse and Left-Inverse +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +inverse_seq(Shape const& shape, Stride const& stride, seq) +{ + if constexpr (I == decltype(rank(stride))::value) { + return seq{}; + } else { + //auto next_stride = get(shape) * get(stride); + using next_stride = decltype(get(shape) * get(stride)); // NOTE: WAR for g++-7 + + if constexpr (is_static::value) { + auto next_idx = find_if(stride, [](auto a) { return is_constant{}; }); + return inverse_seq(shape, stride, seq{}); + } else { + return seq{}; + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the right-inverse of a layout +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(i)) == i for all i < size(@a result) +// @result A layout @a result such that +// composition(@a layout, @a result) is identical to make_layout(shape(result)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); + + // Find Int<1>{}, the starting idx, and follow the strides to gen inverse_seq + auto next_I = find_if(astride, [](auto a) { return is_constant<1, decltype(a)>{}; }); + [[maybe_unused]] auto iseq = detail::inverse_seq(flat_layout.shape(), astride, seq<>{}); + + if constexpr (tuple_size::value == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_col_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Underscore const& _) +{ + return _; +} + +// +// Build the left-inverse of a layout +// @pre is_static +// @pre not has_int0 // @a layout has no 0-strides (is injective) +// @result A layout @a result such that +// @a result(@a layout(i)) == i for all i < size(@a layout) +// @result A layout @a result such that +// composition(@a result, @a layout) is identical to make_layout(shape(layout)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Layout const& layout) +{ + return right_inverse(make_layout(layout, complement(layout))); +} + +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Underscore const& _) +{ + return _; +} + +// +// Max Common Vector +// + +/* Return Int such that N is the maximum number of continguous elements + * that logically correspond in the layouts of @a a and @a b. This is, + * the number of elements that could reasonably be "vectorized" in the layouts. + * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b[n]) == n (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, Layout const& b) +{ + if constexpr (is_static>::value && + is_static>::value) + { + auto result = coalesce(composition(a, right_inverse(b))); + + if constexpr (is_constant<1, decltype(stride<0>(result))>::value) { + return shape<0>(result); + } else { + return Int<1>{}; + } + } else { + // Dynamic case NOTE: could weaken if we assume dynamic strides are large and multiples of the vector + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Zip +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layout) +{ + return make_layout(zip(layout.shape()), + zip(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layoutA, + Layout const& layoutB) +{ + return make_layout(zip(layoutA.shape(), layoutB.shape()), + zip(layoutA.stride(), layoutB.stride())); +} + +// +// Tile unzip +// Logical product and logical divide (on layouts) produce rank-2 results by design. +// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into +// their own mode. +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(Layout const& layout, + IntTuple const& tile) +{ + return make_layout(zip2_by(layout.shape(), tile), + zip2_by(layout.stride(), tile)); +} + +// +// Logical divide +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Layout const& tile) +{ + //CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{}, + // "Tiling does not evenly divide the block"); + // NOTE: With tiles that have stride-0, this doesn't have to be true + + return composition(layout, make_layout(tile, complement(tile, size(layout)))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + IntTuple const& tile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tile."); + return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tile)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the tile modes and residuals into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Layout const& layout, + Tile const& tile) +{ + return tile_unzip(logical_divide(layout, tile), tile); +} + +// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Layout const& layout, + Tile const& tile) +{ + auto div = zipped_divide(layout, tile); + + auto R = rank<1>(div); + return div(_, repeat(_)); +} + +// +// Logical product +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& layout, + Layout const& tile) +{ + return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& layout, + IntTuple const& tile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_product(layout, make_layout(tile)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the block modes and products into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(Layout const& layout, + Tile const& tile) +{ + return tile_unzip(logical_product(layout, tile), tile); +} + +// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& layout, + Tile const& tile) +{ + auto div = zipped_product(layout, tile); + + auto R = rank(tile); + return div(_, repeat(_)); +} + +// Attempts to reproduce layout "block" over layout "layout" +// That is, think of every element of "layout" as a "block" +// and return the layout of the resulting structure +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + Layout const& layout) +{ + constexpr int R = cute::max(rank_v, rank_v); + auto padded_block = append(block); + auto padded_layout = append(layout); + + auto result = logical_product(padded_block, padded_layout); + + return coalesce(zip(get<0>(result), get<1>(result)), repeat(Int<1>{})); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(Layout const& block, + Layout const& layout) +{ + constexpr int R = cute::max(rank_v, rank_v); + auto padded_block = append(block); + auto padded_layout = append(layout); + + auto result = logical_product(padded_block, padded_layout); + + return coalesce(zip(get<1>(result), get<0>(result)), repeat(Int<1>{})); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(Layout const& layout, + TrgShape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + constexpr int R = rank_v; + + auto padded_layout = append(layout); + + auto layout_shape = product_each(padded_layout.shape()); + auto target_shape = product_each(trg_shape); + + // Assert proper division + CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{}, + "Layout shape does not divide the target shape."); + + auto product_shape = shape_div(target_shape, layout_shape); + + return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape); +} + +// +// Upcast +// For stride-1 mode, divide size by N. Divide all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { // tuple stride + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride + return Layout{shape,stride}; + } else if constexpr (is_static::value) { // static stride + return make_layout(shape_div(shape, shape_div(Int{}, abs(stride))), + shape_div(stride, Int{})); + } else { // dynamic stride + // assume dynamic strides are larger than N and divisible + // assert(stride % N == 0); + return make_layout(shape, safe_div(stride, Int{})); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Layout const& layout) +{ + return upcast(layout.shape(), layout.stride()); +} + +// +// Downcast +// For stride-1 mode, multiply size by N. Multiply all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); + } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { + return make_layout(shape * Int{}, stride); + } else { + return make_layout(shape, stride * Int{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Layout const& layout) +{ + CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); + return downcast(layout.shape(), layout.stride()); +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Layout const& layout) +{ + if constexpr (sizeof(NewType) == sizeof(OldType)) { + return layout; + } else if constexpr (sizeof(NewType) > sizeof(OldType)) { + static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); + return upcast(layout); + } else if constexpr (sizeof(NewType) < sizeof(OldType)) { + static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); + return downcast(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Layout const& layout) +{ + print(layout.shape()); print(":"); print(layout.stride()); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) +{ + return os << shape(layout) << ":" << stride(layout); +} + +// Generic 2D Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + int idx_width = num_digits(cosize(layout)) + 2; + const char* delim = "+-----------------------"; + + print(layout); print("\n"); + + // Column indices + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } + printf("\n"); + + // Print out A m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); + // Values + printf("%2d ", m); // Row indices + for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } + printf("|\n"); + } + // Footer + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); +} + +// Generic ThrVal 2D Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + print(layout); print("\n"); + print(thrid); print("\n"); + + // Print out m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); + // Values + for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); + printf("|\n"); + } + // Footer + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); +} + +// Generic 2D Layout to Latex printer -- B&W 8-value color coding +template +CUTE_HOST_DEVICE +void +print_latex(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + char const* latex_header = + "\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"black!00", + "black!40", + "black!20", + "black!60", + "black!10", + "black!50", + "black!30", + "black!70"}; + + // Header + printf("%% Layout: "); print(layout); printf("\n"); + + printf(latex_header); + + // Layout + for (int i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int idx = layout(i,j); + + printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", + color_map[idx % 8], + i, j, + idx); + } + } + + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf(latex_footer); +} + +// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + char const* latex_header = + "\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + + // Header + printf("%% layout: "); print(layout); printf("\n"); + printf("%% thrid: "); print(thr); printf("\n\n"); + + printf(latex_header); + + // Layout + for (int i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int thrid = layout(i,j) % size(thr); + int val_idx = layout(i,j) / size(thr); + int thr_idx = thr(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j, + thr_idx, val_idx); + } + } + + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf(latex_footer); +} + +} // end namespace cute + +// +// Extended Layouts +// + +#include diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp new file mode 100644 index 0000000000..33471e4f16 --- /dev/null +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +template +struct ArithmeticTuple : tuple +{ + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(ArithmeticTuple const& u) + : tuple(static_cast const&>(u)) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(tuple const& u) + : tuple(u) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(U const&... u) + : tuple(u...) {} +}; + +template +struct is_tuple> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_arithmetic_tuple(T const&... t) { + return ArithmeticTuple(t...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(tuple const& t) { + return ArithmeticTuple(t); +} + +// +// Numeric operators +// + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, tuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +// +// Special cases +// + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(constant, ArithmeticTuple const& u) { + return u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, constant) { + return t; +} + +// +// ArithmeticTupleIterator +// + +template +struct ArithmeticTupleIterator +{ + ArithTuple coord_; + + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator() : coord_() {} + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator(ArithTuple const& coord) : coord_(coord) {} + + CUTE_HOST_DEVICE constexpr + ArithTuple const& operator*() const { return coord_; } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator(coord_ + c); + } + + template + CUTE_HOST_DEVICE constexpr + auto operator[](Coord const& c) const { return *(*this + c); } +}; + +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { + printf("ArithTuple"); print(iter.coord_); +} + +// +// ArithmeticTuple "basis" elements +// + +// Abstract value: +// A ScaledBasis is a (at least) rank-N0 ArithmeticTuple: +// (_0,_0,...,T,_0,...) + +template +struct ScaledBasis : private tuple +{ + CUTE_HOST_DEVICE constexpr + ScaledBasis(T const& t = {}) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) value() { return get<0>(static_cast &>(*this)); } + CUTE_HOST_DEVICE constexpr + decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + + CUTE_HOST_DEVICE static constexpr + auto mode() { return Int{}; } +}; + +template +struct is_scaled_basis : false_type {}; +template +struct is_scaled_basis> : true_type {}; + +template +struct is_integral> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr auto +basis_value(T const& e) { + return e; +} + +template +CUTE_HOST_DEVICE constexpr auto +basis_value(ScaledBasis const& e) { + return basis_value(e.value()); +} + +namespace detail { + +template +struct Basis; + +template <> +struct Basis<> { + using type = Int<1>; +}; + +template +struct Basis { + using type = ScaledBasis::type, N>; +}; + +} // end namespace detail + +template +using E = typename detail::Basis::type; + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(T const& t, seq, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { + return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-M ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + static_assert(M > N, "Mismatched ranks"); + return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); +} + +// Turn an ArithmeticTuple into a rank-M ArithmeticTuple +// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t) { + static_assert(M >= sizeof...(T), "Mismatched ranks"); + return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); +} + +// Return... +template +CUTE_HOST_DEVICE constexpr +auto +make_basis_like(Shape const& shape) +{ + if constexpr (is_integral::value) { + return Int<1>{}; + } else { + // Generate bases for each rank of shape + return transform(tuple_seq{}, [&](auto I) { + // Generate bases for each rank of shape_i and add an i on front + constexpr int i = decltype(I)::value; // NOTE: nvcc workaround + return transform_leaf(make_basis_like(get(shape)), [&](auto e) { return ScaledBasis{}; }); + }); + } + + CUTE_GCC_UNREACHABLE; +} + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis, Int) { + return false_type{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Int, ScaledBasis) { + return false_type{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis const& t, ScaledBasis const& u) { + return bool_constant{} && t.value() == u.value(); +} + +// Multiplication +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(A const& a, ScaledBasis const& e) { + return ScaledBasis{a*e.value()}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(ScaledBasis const& e, B const& b) { + return ScaledBasis{e.value()*b}; +} + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(N+1, int(sizeof...(U))); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + constexpr int R = cute::max(int(sizeof...(T)), M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ScaledBasis const& u) { + constexpr int R = cute::max(N+1,M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(constant, ScaledBasis const& u) { + return u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, constant) { + return t; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) { + printf("%d:", N); print(e.value()); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { + return os << N << ":" << e.value(); +} + +} // end namespace cute + + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> + : std::tuple_element> +{}; + +} // end namespace std diff --git a/include/cute/numeric/bfloat.hpp b/include/cute/numeric/bfloat.hpp new file mode 100644 index 0000000000..94f64ab572 --- /dev/null +++ b/include/cute/numeric/bfloat.hpp @@ -0,0 +1,51 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute { + +using cutlass::bfloat16_t; + +// +// Display utilities +// + +CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v) +{ + return os << float(v); +} + +} // end namespace cute diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp new file mode 100644 index 0000000000..3790ebd3b1 --- /dev/null +++ b/include/cute/numeric/complex.hpp @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +//#if defined(__CUDA_ARCH__) +//# include +//#else +//# include +//#endif + +// With CUDA 11.4, builds show spurious "-Wconversion" warnings +// on line 656 of thrust/detail/type_traits.h. +// These pragmas suppress the warnings. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#include +#pragma GCC diagnostic pop + +#include + +namespace cute +{ + +//#if defined(__CUDA_ARCH__) +//template +//using complex = cuda::std::complex; +//#else +//template +//using complex = std::complex; +//#endif + +//template +//using complex = thrust::complex; + +using thrust::complex; + +template +CUTE_HOST_DEVICE +T real(complex const& z) { + return z.real(); +} + +template +CUTE_HOST_DEVICE +T imag(complex const& z) { + return z.imag(); +} + +template +CUTE_HOST_DEVICE +complex conj(complex const& z) { + return complex(real(z), -imag(z)); +} + +// cute::conj forwards scalars +template +CUTE_HOST_DEVICE +T conj(T z) { + return z; +} + +//CUTE_HOST_DEVICE constexpr +//float conj(float z) { return z; } +//CUTE_HOST_DEVICE constexpr +//double conj(double z) { return z; } + +/// Fused multiply-add for complex numbers +template +CUTE_HOST_DEVICE constexpr +void +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) +{ + d.real(c.real() + a.real() * b.real()); + d.imag(c.imag() + a.real() * b.imag()); + d.real(d.real() - a.imag() * b.imag()); + d.imag(d.imag() + a.imag() * b.real()); +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(complex const& a, + complex const& b, + complex & c) +{ + return fma(c, a, b, c); +} + +/// Used to determine the real-valued underlying type of a numeric type T +template +struct RealType { + using Type = T; +}; + +/// Partial specialization for complex-valued type +template +struct RealType> { + using Type = T; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct is_complex { + static bool const value = false; +}; + +template +struct is_complex> { + static bool const value = true; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Display utilities + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) +{ + T _r = z.real(); + T _i = z.imag(); + + if (bool(_i)) { + return os << _r << "+i" << _i; + } else { + return os << _r; + } +} + +} // end namespace cute diff --git a/include/cute/numeric/float8.hpp b/include/cute/numeric/float8.hpp new file mode 100644 index 0000000000..3fa471db34 --- /dev/null +++ b/include/cute/numeric/float8.hpp @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute { + +using cutlass::float_e4m3_t; +using cutlass::float_e5m2_t; + +} // end namespace cute diff --git a/include/cute/numeric/half.hpp b/include/cute/numeric/half.hpp new file mode 100644 index 0000000000..704ba28d55 --- /dev/null +++ b/include/cute/numeric/half.hpp @@ -0,0 +1,41 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute { + +using cutlass::half_t; + +} // end namespace cute diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp new file mode 100644 index 0000000000..a08297f209 --- /dev/null +++ b/include/cute/numeric/int.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include + +namespace cute +{ + +// +// Signed integers +// + +using int8_t = std::int8_t; +using int16_t = std::int16_t; +using int32_t = std::int32_t; +using int64_t = std::int64_t; + +template struct int_bit; +template <> struct int_bit< 2> { using type = cute::int2b_t; }; +template <> struct int_bit< 4> { using type = cute::int4b_t; }; +template <> struct int_bit< 8> { using type = int8_t; }; +template <> struct int_bit< 16> { using type = int16_t; }; +template <> struct int_bit< 32> { using type = int32_t; }; +template <> struct int_bit< 64> { using type = int64_t; }; + +template +using int_bit_t = typename int_bit::type; + +template +using int_byte = int_bit<8*N>; + +template +using int_byte_t = typename int_byte::type; + +// +// Unsigned integers +// + +using uint8_t = std::uint8_t; +using uint16_t = std::uint16_t; +using uint32_t = std::uint32_t; +using uint64_t = std::uint64_t; + +template struct uint_bit; +template <> struct uint_bit< 1> { using type = cute::uint1b_t; }; +template <> struct uint_bit< 2> { using type = cute::uint2b_t; }; +template <> struct uint_bit< 4> { using type = cute::uint4b_t; }; +template <> struct uint_bit< 8> { using type = uint8_t; }; +template <> struct uint_bit< 16> { using type = uint16_t; }; +template <> struct uint_bit< 32> { using type = uint32_t; }; +template <> struct uint_bit< 64> { using type = uint64_t; }; +template <> struct uint_bit<128> { using type = cute::uint128_t; }; + +template +using uint_bit_t = typename uint_bit::type; + +template +using uint_byte = uint_bit<8*N>; + +template +using uint_byte_t = typename uint_byte::type; + +// +// sizeof_bytes +// + +template +struct sizeof_bytes { + static constexpr std::size_t value = sizeof(T); +}; +template +static constexpr int sizeof_bytes_v = sizeof_bytes::value; + +// +// sizeof_bits +// + +template +struct sizeof_bits { + static constexpr std::size_t value = sizeof(T) * 8; +}; +template <> +struct sizeof_bits { + static constexpr std::size_t value = 1; +}; +template +struct sizeof_bits> { + static constexpr std::size_t value = Bits; +}; +template +static constexpr int sizeof_bits_v = sizeof_bits::value; + +} // namespace cute diff --git a/include/cute/numeric/integer_sequence.hpp b/include/cute/numeric/integer_sequence.hpp new file mode 100644 index 0000000000..73a83f76a9 --- /dev/null +++ b/include/cute/numeric/integer_sequence.hpp @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // std::integer_sequence + +#include + +namespace cute +{ + +using std::integer_sequence; +using std::make_integer_sequence; + +namespace detail { + +template +struct make_integer_range_impl; + +template +struct make_integer_range_impl, Begin> { + using type = integer_sequence; +}; + +} // end namespace detail + +template +using make_integer_range = typename detail::make_integer_range_impl< + T, + make_integer_sequence 0) ? (End-Begin) : 0>, + Begin>::type; + +// +// Common aliases +// + +// int_sequence + +template +using int_sequence = integer_sequence; + +template +using make_int_sequence = make_integer_sequence; + +template +using make_int_range = make_integer_range; + +// index_sequence + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +template +using make_index_range = make_integer_range; + +// +// Shortcuts +// + +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using make_range = make_int_range; + +template +using tuple_seq = make_seq>::value>; + +} // end namespace cute + + +// +// Specialize tuple-related functionality for cute::integer_sequence +// + +#include +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +std::tuple_element_t> +get(integer_sequence) { + static_assert(I < sizeof...(Ints), "Index out of range"); + return {}; +} + +} // end namespace cute + +namespace std +{ + +template +struct tuple_size> + : std::integral_constant +{}; + +template +struct tuple_element> + : std::tuple_element...>> +{}; + +} // end namespace std diff --git a/include/cute/numeric/integer_subbyte.hpp b/include/cute/numeric/integer_subbyte.hpp new file mode 100644 index 0000000000..3d24a95293 --- /dev/null +++ b/include/cute/numeric/integer_subbyte.hpp @@ -0,0 +1,233 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct integer_subbyte +{ + /// Storage type + using Storage = uint8_t; + + /// Number of bits + static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); + + /// External type + using xint_t = typename std::conditional::type; + + /// Bitmask for truncation from larger integers + static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1); + /// Bitmask for the sign bit + static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1)); + + // + // Data members + // + + Storage storage; + + // + // Methods + // + + /// No operation + CUTE_HOST_DEVICE constexpr + integer_subbyte() {} + + /// Conversion from integer type + CUTE_HOST_DEVICE constexpr + integer_subbyte(int value) // NOTE: Sign extension? + : storage(reinterpret_cast(value) & bits_mask_) {} + + CUTE_HOST_DEVICE constexpr + integer_subbyte(unsigned value) + : storage(reinterpret_cast(value) & bits_mask_) {} + + /// Convert to int or unsigned + CUTE_HOST_DEVICE constexpr + operator xint_t() const { + if (sign_mask_ & storage) { // Sign extend + return xint_t(storage) | ~xint_t(bits_mask_); + } else { + return xint_t(storage); + } + } + + /// Equality + CUTE_HOST_DEVICE constexpr + bool operator==(integer_subbyte const& rhs) const { + return storage == rhs.storage; + } + + /// Inequality + CUTE_HOST_DEVICE constexpr + bool operator!=(integer_subbyte const& rhs) const { + return storage != rhs.storage; + } + + /// Less than or equal + CUTE_HOST_DEVICE constexpr + bool operator<=(integer_subbyte const& rhs) const { + if (sign_mask_ & storage) { + return !(rhs.storage < storage); + } else { + return storage < rhs.storage; + } + } + + /// Less than + CUTE_HOST_DEVICE constexpr + bool operator<(integer_subbyte const& rhs) const { + if (sign_mask_ & storage) { + return !(rhs.storage <= storage); + } else { + return storage < rhs.storage; + } + } + + /// Greater than or equal + CUTE_HOST_DEVICE constexpr + bool operator>=(integer_subbyte const& rhs) const { + return !(*this < rhs); + } + + /// Greater than + CUTE_HOST_DEVICE constexpr + bool operator>(integer_subbyte const& rhs) const { + return !(*this <= rhs); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// 1-bit unsigned integer type +using uint1b_t = integer_subbyte<1, false>; + +/// 2-bit integer type +using int2b_t = integer_subbyte<2, true>; + +/// 2-bit unsigned integer type +using uint2b_t = integer_subbyte<2, false>; + +/// 4-bit integer type +using int4b_t = integer_subbyte<4, true>; + +/// 4-bit unsigned integer type +using uint4b_t = integer_subbyte<4, false>; + +/// 1-bit binary type +using bin1_t = bool; + +} // namespace cute + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) + +#include + +namespace std { + +template <> +struct numeric_limits { + CUTE_HOST_DEVICE static constexpr + cute::uint1b_t const lowest() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint1b_t const min() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint1b_t const max() noexcept { return 1; } + static constexpr bool is_integer = true; + static constexpr bool is_signed = false; +}; + +template <> +struct numeric_limits { + CUTE_HOST_DEVICE static constexpr + cute::int2b_t lowest() noexcept { return -2; } + CUTE_HOST_DEVICE static constexpr + cute::int2b_t min() noexcept { return -2; } + CUTE_HOST_DEVICE static constexpr + cute::int2b_t max() noexcept { return 1; } + static constexpr bool is_integer = true; + static constexpr bool is_signed = true; +}; + +template <> +struct numeric_limits { + CUTE_HOST_DEVICE static constexpr + cute::uint2b_t const lowest() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint2b_t const min() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint2b_t const max() noexcept { return 3; } + static constexpr bool is_integer = true; + static constexpr bool is_signed = false; +}; + +template <> +struct numeric_limits { + CUTE_HOST_DEVICE static constexpr + cute::int4b_t lowest() noexcept { return -8; } + CUTE_HOST_DEVICE static constexpr + cute::int4b_t min() noexcept { return -8; } + CUTE_HOST_DEVICE static constexpr + cute::int4b_t max() noexcept { return 7; } + static constexpr bool is_integer = true; + static constexpr bool is_signed = true; +}; + +template <> +struct numeric_limits { + CUTE_HOST_DEVICE static constexpr + cute::uint4b_t const lowest() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint4b_t const min() noexcept { return 0; } + CUTE_HOST_DEVICE static constexpr + cute::uint4b_t const max() noexcept { return 15; } + static constexpr bool is_integer = true; + static constexpr bool is_signed = false; +}; + +} // namespace std + +#endif diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp new file mode 100644 index 0000000000..106763df58 --- /dev/null +++ b/include/cute/numeric/integral_constant.hpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template +struct constant : std::integral_constant { + static constexpr T value = v; + using value_type = T; + using type = constant; + CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + +template +using integral_constant = constant; + +template +using bool_constant = constant; + +using true_type = bool_constant; +using false_type = bool_constant; + +// +// Traits +// + +// Use std::is_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use cute::is_integral to match both built-in integral types AND constant + +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral> : true_type {}; + +// is_static detects if an (abstract) value is defined completely by it's type (no members) + +template +struct is_static : bool_constant::value> {}; + +// is_constant detects if a type is a constant and if v is equal to a value + +template +struct is_constant : false_type {}; +template +struct is_constant > : bool_constant {}; +template +struct is_constant const > : bool_constant {}; +template +struct is_constant const&> : bool_constant {}; +template +struct is_constant &> : bool_constant {}; +template +struct is_constant &&> : bool_constant {}; + +// +// Specializations +// + +template +using Int = constant; + +using _m32 = Int<-32>; +using _m24 = Int<-24>; +using _m16 = Int<-16>; +using _m12 = Int<-12>; +using _m10 = Int<-10>; +using _m9 = Int<-9>; +using _m8 = Int<-8>; +using _m7 = Int<-7>; +using _m6 = Int<-6>; +using _m5 = Int<-5>; +using _m4 = Int<-4>; +using _m3 = Int<-3>; +using _m2 = Int<-2>; +using _m1 = Int<-1>; +using _0 = Int<0>; +using _1 = Int<1>; +using _2 = Int<2>; +using _3 = Int<3>; +using _4 = Int<4>; +using _5 = Int<5>; +using _6 = Int<6>; +using _7 = Int<7>; +using _8 = Int<8>; +using _9 = Int<9>; +using _10 = Int<10>; +using _12 = Int<12>; +using _16 = Int<16>; +using _24 = Int<24>; +using _32 = Int<32>; +using _64 = Int<64>; +using _96 = Int<96>; +using _128 = Int<128>; +using _192 = Int<192>; +using _256 = Int<256>; +using _512 = Int<512>; +using _1024 = Int<1024>; +using _2048 = Int<2048>; +using _4096 = Int<4096>; +using _8192 = Int<8192>; + +/***************/ +/** Operators **/ +/***************/ + +#define CUTE_LEFT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + constant \ + operator OP (constant) { \ + return {}; \ + } +#define CUTE_RIGHT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + constant \ + operator OP (constant) { \ + return {}; \ + } + +#define CUTE_BINARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + constant \ + operator OP (constant, constant) { \ + return {}; \ + } + +CUTE_LEFT_UNARY_OP(+); +CUTE_LEFT_UNARY_OP(-); +CUTE_LEFT_UNARY_OP(~); +CUTE_LEFT_UNARY_OP(!); +CUTE_LEFT_UNARY_OP(*); + +CUTE_BINARY_OP( +); +CUTE_BINARY_OP( -); +CUTE_BINARY_OP( *); +CUTE_BINARY_OP( /); +CUTE_BINARY_OP( %); +CUTE_BINARY_OP( &); +CUTE_BINARY_OP( |); +CUTE_BINARY_OP( ^); +CUTE_BINARY_OP(<<); +CUTE_BINARY_OP(>>); + +CUTE_BINARY_OP(&&); +CUTE_BINARY_OP(||); + +CUTE_BINARY_OP(==); +CUTE_BINARY_OP(!=); +CUTE_BINARY_OP( >); +CUTE_BINARY_OP( <); +CUTE_BINARY_OP(>=); +CUTE_BINARY_OP(<=); + +#undef CUTE_BINARY_OP +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP + +// +// Mixed static-dynamic special cases +// + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator*(constant, U) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator*(U, constant) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator/(constant, U) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator%(U, constant) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator%(U, constant) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator%(constant, U) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator&(constant, U) { + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +constant +operator&(U, constant) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +constant +operator&&(constant, U) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +constant +operator&&(U, constant) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +constant +operator||(constant, U) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +constant +operator||(U, constant) { + return {}; +} + +// +// Named functions from math.hpp +// + +#define CUTE_NAMED_UNARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + constant \ + OP (constant) { \ + return {}; \ + } + +#define CUTE_NAMED_BINARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + constant \ + OP (constant, constant) { \ + return {}; \ + } \ + \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto \ + OP (constant, U u) { \ + return OP(t,u); \ + } \ + \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto \ + OP (T t, constant) { \ + return OP(t,u); \ + } + +CUTE_NAMED_UNARY_FN(abs); +CUTE_NAMED_UNARY_FN(signum); +CUTE_NAMED_UNARY_FN(has_single_bit); + +CUTE_NAMED_BINARY_FN(max); +CUTE_NAMED_BINARY_FN(min); +CUTE_NAMED_BINARY_FN(shiftl); +CUTE_NAMED_BINARY_FN(shiftr); +CUTE_NAMED_BINARY_FN(gcd); +CUTE_NAMED_BINARY_FN(lcm); + +#undef CUTE_NAMED_UNARY_FN +#undef CUTE_NAMED_BINARY_FN + +// +// Other functions +// + +template +CUTE_HOST_DEVICE constexpr +constant +safe_div(constant, constant) { + static_assert(t % u == 0, "Static safe_div requires t % u == 0"); + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(constant, U u) { + return t / u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T t, constant) { + return t / u; +} + +// cute::true_type prefers standard conversion to std::true_type +// over user-defined conversion to bool +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(std::true_type, TrueType&& t, FalseType&&) { + return static_cast(t); +} + +// cute::false_type prefers standard conversion to std::false_type +// over user-defined conversion to bool +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(std::false_type, TrueType&&, FalseType&& f) { + return static_cast(f); +} + +// TrueType and FalseType must have a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, TrueType const& t, FalseType const& f) { + return b ? t : f; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(integral_constant const&) { + printf("_%d", N); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { + return os << "_" << N; +} + +} // end namespace cute diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp new file mode 100644 index 0000000000..03e8379977 --- /dev/null +++ b/include/cute/numeric/math.hpp @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include + +namespace cute +{ + +// +// Common Operations +// + +template ::value && + std::is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +max(T const& t, U const& u) { + return t < u ? u : t; +} + +template ::value && + std::is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +min(T const& t, U const& u) { + return t < u ? t : u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +abs(T const& t) { + if constexpr (std::is_signed::value) { + return t < T(0) ? -t : t; + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// C++17 operations +// + +// Greatest common divisor of two integers +template ::value && + std::is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +gcd(T t, U u) { + while (true) { + if (t == 0) { return u; } + u %= t; + if (u == 0) { return t; } + t %= u; + } +} + +// Least common multiple of two integers +template ::value && + std::is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +lcm(T const& t, U const& u) { + return (t / gcd(t,u)) * u; +} + +// +// C++20 operations +// + +// Checks if a number is an integral power of two +template +CUTE_HOST_DEVICE constexpr +bool +has_single_bit(T x) { + return x != 0 && (x & (x - 1)) == 0; +} + +// Smallest number of bits needed to represent the given value +// bit_width( 0b0000 ) = 0 +// bit_width( 0b0001 ) = 1 +// bit_width( 0b0010 ) = 2 +// bit_width( 0b0011 ) = 2 +// bit_width( 0b0100 ) = 3 +// bit_width( 0b0101 ) = 3 +// bit_width( 0b0110 ) = 3 +// bit_width( 0b0111 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +bit_width(T x) { + static_assert(std::is_unsigned::value, "Only to be used for unsigned types."); + constexpr int N = (std::numeric_limits::digits == 64 ? 6 : + (std::numeric_limits::digits == 32 ? 5 : + (std::numeric_limits::digits == 16 ? 4 : + (std::numeric_limits::digits == 8 ? 3 : (assert(false),0))))); + T r = 0; + for (int i = N - 1; i >= 0; --i) { + T shift = (x > ((T(1) << (T(1) << i))-1)) << i; + x >>= shift; + r |= shift; + } + return r + (x != 0); +} + +// Smallest integral power of two not less than the given value +// bit_ceil( 0b00000000 ) = 0b00000001 +// bit_ceil( 0b00000001 ) = 0b00000001 +// bit_ceil( 0b00000010 ) = 0b00000010 +// bit_ceil( 0b00000011 ) = 0b00000100 +// bit_ceil( 0b00000100 ) = 0b00000100 +// bit_ceil( 0b00000101 ) = 0b00001000 +// bit_ceil( 0b00000110 ) = 0b00001000 +// bit_ceil( 0b00000111 ) = 0b00001000 +// bit_ceil( 0b00001000 ) = 0b00001000 +// bit_ceil( 0b00001001 ) = 0b00010000 +template +CUTE_HOST_DEVICE constexpr +T +bit_ceil(T x) { + return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); +} + +// Largest integral power of two not greater than the given value +// bit_floor( 0b00000000 ) = 0b00000000 +// bit_floor( 0b00000001 ) = 0b00000001 +// bit_floor( 0b00000010 ) = 0b00000010 +// bit_floor( 0b00000011 ) = 0b00000010 +// bit_floor( 0b00000100 ) = 0b00000100 +// bit_floor( 0b00000101 ) = 0b00000100 +// bit_floor( 0b00000110 ) = 0b00000100 +// bit_floor( 0b00000111 ) = 0b00000100 +// bit_floor( 0b00001000 ) = 0b00001000 +// bit_floor( 0b00001001 ) = 0b00001000 +template +CUTE_HOST_DEVICE constexpr +T +bit_floor(T x) { + return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); +} + +template +CUTE_HOST_DEVICE constexpr T rotl(T x, int s); +template +CUTE_HOST_DEVICE constexpr T rotr(T x, int s); + +// Computes the result of circular bitwise left-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotl(T x, int s) { + constexpr int N = std::numeric_limits::digits; + return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s); +} + +// Computes the result of circular bitwise right-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotr(T x, int s) { + constexpr int N = std::numeric_limits::digits; + return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s); +} + +// Counts the number of consecutive 0 bits, starting from the most significant bit +// countl_zero( 0b00000000 ) = 8 +// countl_zero( 0b11111111 ) = 0 +// countl_zero( 0b00011100 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +countl_zero(T x) { + return std::numeric_limits::digits - bit_width(x); +} + +// Counts the number of consecutive 1 bits, starting from the most significant bit +// countl_one( 0b00000000 ) = 0 +// countl_one( 0b11111111 ) = 8 +// countl_one( 0b11100011 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +countl_one(T x) { + return countl_zero(~x); +} + +// Counts the number of consecutive 0 bits, starting from the least significant bit +// countr_zero( 0b00000000 ) = 8 +// countr_zero( 0b11111111 ) = 0 +// countr_zero( 0b00011100 ) = 2 +template +CUTE_HOST_DEVICE constexpr +T +countr_zero(T x) { + return x == 0 ? std::numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB +} + +// Counts the number of consecutive 1 bits, starting from the least significant bit +// countr_one( 0b00000000 ) = 0 +// countr_one( 0b11111111 ) = 8 +// countr_one( 0b11100011 ) = 2 +template +CUTE_HOST_DEVICE constexpr +T +countr_one(T x) { + return countr_zero(~x); +} + +// Counts the number of 1 bits in an unsigned integer +// popcount( 0b00000000 ) = 0 +// popcount( 0b11111111 ) = 8 +// popcount( 0b00011101 ) = 4 +template +CUTE_HOST_DEVICE constexpr +int +popcount(T x) { + int c = 0; + while (x) { + ++c; + x &= x - 1; // clear the least significant bit set + } + return c; +} + +// +// Custom operations +// + +// Computes the result of bitwise left-shift +template +CUTE_HOST_DEVICE constexpr +T +shiftl(T x, int s) { + return s >= 0 ? (x << s) : (x >> -s); +} + +// Computes the result of bitwise right-shift +template +CUTE_HOST_DEVICE constexpr +T +shiftr(T x, int s) { + return s >= 0 ? (x >> s) : (x << -s); +} + +// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. +template ::value)> +CUTE_HOST_DEVICE constexpr +int +signum(T const& x) { + return T(0) < x; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +int +signum(T const& x) { + return (T(0) < x) - (x < T(0)); +} + +// Safe divide +// @pre t % u == 0 +// @result t / u +template ::value && + std::is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T const& t, U const& u) { + //assert(t % u == 0); + return t / u; +} + +} // namespace cute diff --git a/include/cute/numeric/real.hpp b/include/cute/numeric/real.hpp new file mode 100644 index 0000000000..d85e30405a --- /dev/null +++ b/include/cute/numeric/real.hpp @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +namespace cute +{ + +/// Generic fused multiply-add +template +CUTE_HOST_DEVICE constexpr +void +fma(D& d, A const& a, B const& b, C const& c) +{ + d = a * b + c; +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(A const& a, B const& b, C& c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/include/cute/numeric/tfloat.hpp b/include/cute/numeric/tfloat.hpp new file mode 100644 index 0000000000..bb68b703eb --- /dev/null +++ b/include/cute/numeric/tfloat.hpp @@ -0,0 +1,51 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute { + +using cutlass::tfloat32_t; + +// +// Display utilities +// + +CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v) +{ + return os << float(v); +} + +} // end namespace cute diff --git a/include/cute/numeric/uint128.hpp b/include/cute/numeric/uint128.hpp new file mode 100644 index 0000000000..fb02441fae --- /dev/null +++ b/include/cute/numeric/uint128.hpp @@ -0,0 +1,259 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#include +#include +#include +#include +#endif + +#include + +/// Optionally enable GCC's built-in type +#if defined(__x86_64) && !defined(__CUDA_ARCH__) +# if defined(__GNUC__) && 0 +# define CUTE_UINT128_NATIVE +# elif defined(_MSC_VER) +# define CUTE_INT128_ARITHMETIC +# include +# endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///! Unsigned 128b integer type +struct alignas(16) uint128_t +{ + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 64; + + struct hilo + { + uint64_t lo; + uint64_t hi; + }; + + // Use a union to store either low and high parts or, if present, a built-in 128b integer type. + union + { + struct hilo hilo_; + +#if defined(CUTE_UINT128_NATIVE) + unsigned __int128 native; +#endif // defined(CUTE_UINT128_NATIVE) + }; + + // + // Methods + // + + /// Default ctor + CUTE_HOST_DEVICE constexpr + uint128_t() : hilo_{0, 0} {} + + /// Constructor from uint64 + CUTE_HOST_DEVICE constexpr + uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} + + /// Constructor from two 64b unsigned integers + CUTE_HOST_DEVICE constexpr + uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} + + /// Optional constructor from native value +#if defined(CUTE_UINT128_NATIVE) + uint128_t(unsigned __int128 value) : native(value) { } +#endif + + /// Lossily cast to uint64 + CUTE_HOST_DEVICE constexpr + explicit operator uint64_t() const + { + return hilo_.lo; + } + + template + CUTE_HOST_DEVICE constexpr + static void exception() + { + //static_assert(sizeof(Dummy) == 0, "Not implemented exception!"); + //abort(); + //printf("uint128 not implemented!\n"); + } + + /// Add + CUTE_HOST_DEVICE constexpr + uint128_t operator+(uint128_t const& rhs) const + { + uint128_t y; +#if defined(CUTE_UINT128_NATIVE) + y.native = native + rhs.native; +#else + y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); +#endif + return y; + } + + /// Subtract + CUTE_HOST_DEVICE constexpr + uint128_t operator-(uint128_t const& rhs) const + { + uint128_t y; +#if defined(CUTE_UINT128_NATIVE) + y.native = native - rhs.native; +#else + y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); +#endif + return y; + } + + /// Multiply by unsigned 64b integer yielding 128b integer + CUTE_HOST_DEVICE constexpr + uint128_t operator*(uint64_t const& rhs) const + { + uint128_t y; +#if defined(CUTE_UINT128_NATIVE) + y.native = native * rhs; +#elif defined(CUTE_INT128_ARITHMETIC) + // Multiply by the low part + y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); + + // Add the high part and ignore the overflow + uint64_t overflow; + y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); +#else + exception(); +#endif + return y; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTE_HOST_DEVICE constexpr + uint64_t operator/(uint64_t const& divisor) const + { + uint64_t quotient = 0; +#if defined(CUTE_UINT128_NATIVE) + quotient = uint64_t(native / divisor); +#elif defined(CUTE_INT128_ARITHMETIC) + // implemented using MSVC's arithmetic intrinsics + uint64_t remainder = 0; + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + exception(); +#endif + return quotient; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTE_HOST_DEVICE constexpr + uint64_t operator%(uint64_t const& divisor) const + { + uint64_t remainder = 0; +#if defined(CUTE_UINT128_NATIVE) + remainder = uint64_t(native % divisor); +#elif defined(CUTE_INT128_ARITHMETIC) + // implemented using MSVC's arithmetic intrinsics + (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + exception(); +#endif + return remainder; + } + + /// Computes the quotient and remainder in a single method. + CUTE_HOST_DEVICE constexpr + uint64_t divmod(uint64_t &remainder, uint64_t divisor) const + { + uint64_t quotient = 0; +#if defined(CUTE_UINT128_NATIVE) + quotient = uint64_t(native / divisor); + remainder = uint64_t(native % divisor); +#elif defined(CUTE_INT128_ARITHMETIC) + // implemented using MSVC's arithmetic intrinsics + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + exception(); +#endif + return quotient; + } + + /// Left-shifts a 128b unsigned integer + CUTE_HOST_DEVICE constexpr + uint128_t operator<<(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t(0, hilo_.lo << (sh - storage_bits_)); + } + else { + return uint128_t( + (hilo_.lo << sh), + (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) + ); + } + } + + /// Right-shifts a 128b unsigned integer + CUTE_HOST_DEVICE constexpr + uint128_t operator>>(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); + } + else { + return uint128_t( + (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), + (hilo_.hi >> sh) + ); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp new file mode 100644 index 0000000000..40ce5d1aef --- /dev/null +++ b/include/cute/pointer.hpp @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +namespace cute +{ + +// +// has_dereference to determine if a type is a pointer concept +// + +template +struct has_dereference : std::false_type { +}; + +template +struct has_dereference())>> : std::true_type { +}; + +// +// Pointer categories +// + +template +struct is_gmem : false_type {}; + +template +struct is_smem : false_type {}; + +// Anything that is not gmem or smem is rmem +template +struct is_rmem : bool_constant< not (is_gmem::value || is_smem::value)> {}; + +// +// A very simplified wrapper for pointers -- use for constructing tagged pointers +// +template +struct device_ptr +{ + using value_type = T; + + CUTE_HOST_DEVICE constexpr + device_ptr(T* ptr) : ptr_(ptr) {} + + CUTE_HOST_DEVICE constexpr + T* get() const { return ptr_; } + + CUTE_HOST_DEVICE constexpr + T& operator*() const { return *ptr_; } + + template + CUTE_HOST_DEVICE constexpr + T& operator[](Index const& i) const { return ptr_[i]; } + + template + CUTE_HOST_DEVICE constexpr + DerivedType operator+(Index const& i) const { return {ptr_ + i}; } + + CUTE_HOST_DEVICE constexpr friend + std::ptrdiff_t operator-(device_ptr const& a, + device_ptr const& b) { + return a.ptr_ - b.ptr_; + } + + T* ptr_; +}; + +// +// gmem_ptr +// + +template +struct gmem_ptr : device_ptr> { + using device_ptr>::device_ptr; +}; + +template +CUTE_HOST_DEVICE constexpr +gmem_ptr +make_gmem_ptr(T* ptr) { + return {ptr}; +} + +template +CUTE_HOST_DEVICE constexpr +gmem_ptr +make_gmem_ptr(void* ptr) { + return {reinterpret_cast(ptr)}; +} + +template +struct is_gmem> : true_type {}; + +// +// smem_ptr +// + +template +struct smem_ptr : device_ptr> { + using device_ptr>::device_ptr; +}; + +template +CUTE_HOST_DEVICE constexpr +smem_ptr +make_smem_ptr(T* ptr) { + return {ptr}; +} + +template +CUTE_HOST_DEVICE constexpr +smem_ptr +make_smem_ptr(void* ptr) { + return {reinterpret_cast(ptr)}; +} + +template +struct is_smem> : true_type {}; + +// +// rmem_ptr +// + +template +struct rmem_ptr : device_ptr> { + using device_ptr>::device_ptr; +}; + +template +CUTE_HOST_DEVICE constexpr +rmem_ptr +make_rmem_ptr(T* ptr) { + return {ptr}; +} + +template +CUTE_HOST_DEVICE constexpr +rmem_ptr +make_rmem_ptr(void* ptr) { + return {reinterpret_cast(ptr)}; +} + +template +struct is_rmem> : true_type {}; + +// +// counting iterator -- quick and dirty +// + +struct counting +{ + using index_type = int; + using value_type = index_type; + + CUTE_HOST_DEVICE constexpr + counting() : n_(0) {} + CUTE_HOST_DEVICE constexpr + counting(index_type const& n) : n_(n) {} + + CUTE_HOST_DEVICE constexpr + index_type operator[](index_type const& i) const { return n_ + i; } + + CUTE_HOST_DEVICE constexpr + index_type const& operator*() const { return n_; } + + CUTE_HOST_DEVICE constexpr + counting operator+(index_type const& i) const { return {n_ + i}; } + CUTE_HOST_DEVICE constexpr + counting& operator++() { ++n_; return *this; } + + CUTE_HOST_DEVICE constexpr + bool operator==(counting const& other) const { return n_ == other.n_; } + CUTE_HOST_DEVICE constexpr + bool operator!=(counting const& other) const { return n_ != other.n_; } + + CUTE_HOST_DEVICE constexpr + bool operator< (counting const& other) const { return n_ < other.n_; } + + index_type n_; +}; + +// +// recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast(T* ptr) { + return reinterpret_cast(ptr); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(T const* ptr) { + return reinterpret_cast(ptr); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(gmem_ptr const& ptr) { + return make_gmem_ptr(recast(ptr.ptr_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(gmem_ptr const& ptr) { + return make_gmem_ptr(recast(ptr.ptr_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(smem_ptr const& ptr) { + return make_smem_ptr(recast(ptr.ptr_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(smem_ptr const& ptr) { + return make_smem_ptr(recast(ptr.ptr_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(rmem_ptr const& ptr) { + return make_rmem_ptr(recast(ptr.ptr_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(rmem_ptr const& ptr) { + return make_rmem_ptr(recast(ptr.ptr_)); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(T const* const ptr) +{ + printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr); +} + +template +CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) +{ + printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr const& ptr) +{ + printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) +{ + printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) +{ + return os << "gmem_ptr_" << int(8*sizeof(T)) << "b"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) +{ + return os << "smem_ptr_" << int(8*sizeof(T)) << "b"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) +{ + return os << "rmem_ptr_" << int(8*sizeof(T)) << "b"; +} + +} // end namespace cute diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp new file mode 100644 index 0000000000..5fb0da8aec --- /dev/null +++ b/include/cute/stride.hpp @@ -0,0 +1,411 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** crd2idx maps a coordinate within to an index + * This is computed as follows: + * [coord, shape, and stride are all integers => step forward by stride] + * op(c, s, d) => c * d + * [coord is integer, shape and stride are tuple => divmod coord for each mode] + * op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D)) + * [coord, shape, and stride are all tuples => consider each mode independently] + * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) + */ + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_ttt(Coord const& coord, + Shape const& shape, + Stride const& stride, seq) +{ + return (... + crd2idx(get(coord), get(shape), get(stride))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_itt(CInt const& coord, + STuple const& shape, + DTuple const& stride, seq) +{ + if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter + return crd2idx(coord, get(shape), get(stride)); + } else { // General case + return crd2idx(coord % product(get(shape)), get(shape), get(stride)) + + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_ttt(coord, shape, stride, tuple_seq{}); + } else { // tuple "int" "int" + static_assert(sizeof(Coord) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_itt(coord, shape, stride, tuple_seq{}); + } else { // "int" "int" "int" + return coord * stride; + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// If we know Stride is default [CompactColMajor], then we can take shortcuts +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_horner(CTuple const& coord, + STuple const& shape, seq) +{ + if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter + return get(coord); + } else { // General case + return get(coord) + get(shape) * crd2idx_horner(coord, shape, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape) +{ + static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks"); + if constexpr (is_tuple::value) { + // Flatten and apply Horner's method + auto flat_coord = flatten(coord); + auto flat_shape = flatten(shape); + return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); + } else { + return coord; + } + + CUTE_GCC_UNREACHABLE; +} + +/** idx2crd splits an index to a coordinate within . + * + * This is computed as follows: + * [index, shape, and stride are all integers => determine 1D coord] + * op(i, s, d) => (i / d) % s + * [index is integer, shape and stride are tuple => determine component for each mode] + * op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...) + * [index, shape, and stride are all tuples => consider each mode independently] + * op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D))) + * + * NOTE: This only works for compact shape+stride layouts. A more general version would + * apply to all surjective layouts + */ + +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); }); + } else { // tuple "int" "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } else { // "int" tuple "int" + return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } + } else { // "int" "int" "int" + return (idx / stride) % shape; + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// If we know Stride is default [CompactColMajor], then we can take shortcuts +// + +//(idx / 1) % s0 +//(idx / s0) % s1 +//(idx / (s0 * s1)) % s2 +//... + +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); }); + } else { // tuple "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple + return idx2crd(idx, shape, compact_col_major(shape)); + } else { // "int" "int" + return idx; + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// crd2crd +// + +template +CUTE_HOST_DEVICE constexpr +auto +crd2crd(Coord const& coord, + SShape const& src_shape, + DShape const& dst_shape) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); }); + } else { + // assert(size(src_shape) == size(dst_shape)) + return idx2crd(crd2idx(coord, src_shape), dst_shape); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Major +// + +// General tag for common layouts and dispatching +struct GenColMajor {}; +struct GenRowMajor {}; + +template , class Major = GenColMajor> +CUTE_HOST_DEVICE constexpr +auto +compact_major(Shape const& shape, + Current const& current = {}, + Major const& major = {}); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +compact_major_ti(Shape const& shape, + Current const& current, + GenColMajor const& major, seq) +{ + return cute::make_tuple(compact_major(get(shape), current * product<0,Is>(shape), major)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_major_ti(Shape const& shape, + Current const& current, + GenRowMajor const& major, seq) +{ + constexpr int E = tuple_size::value; + return cute::make_tuple(compact_major(get(shape), current * product(shape), major)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +compact_major(Shape const& shape, + Current const& current, + Major const& major) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); }); + } else { // tuple int + return detail::compact_major_ti(shape, current, major, tuple_seq{}); + } + } else { + if constexpr (is_tuple::value) { // int tuple + static_assert(sizeof(Shape) == 0, "Invalid parameters"); + } else { // int int + if constexpr (is_constant<1, Shape>::value) { + return Int<0>{}; // If current is dynamic, this could save a reg + } else { + return current; + } + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Col Major +// + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_col_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current, GenColMajor{}); +} + +template +using ColMajor = decltype(compact_col_major(std::declval())); + +// +// Compact Row Major +// + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_row_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current, GenRowMajor{}); +} + +template +using RowMajor = decltype(compact_row_major(std::declval())); + +// +// Compact Order -- compute a compact stride based on an ordering of the modes +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order, + OrigShape const& orig_shape, OrigOrder const& orig_order) +{ + if constexpr (is_tuple::value) { + return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); }); + } else { + auto d = product(transform(orig_shape, orig_order, + [&](auto const& s, auto const& o) { + return conditional_return(o < order, product(s), Int<1>{}); + })); + return compact_col_major(shape, d); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order) +{ + static_assert(is_congruent::value, "Need congruence of shape and order."); + return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenColMajor const& major) +{ + return compact_major(shape, Int<1>{}, major); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenRowMajor const& major) +{ + return compact_major(shape, Int<1>{}, major); +} + +} // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp new file mode 100644 index 0000000000..0a13e55143 --- /dev/null +++ b/include/cute/swizzle.hpp @@ -0,0 +1,497 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +// A generic Swizzle functor +/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + * ^--^ MBase is the number of least-sig bits to keep constant + * ^-^ ^-^ BBits is the number of bits in the mask + * ^---------^ SShift is the distance to shift the YYY mask + * (pos shifts YYY to the right, neg shifts YYY to the left) + * + * e.g. Given + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + * the result is + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + */ +template +struct Swizzle +{ + static constexpr int num_bits = BBits; + static constexpr int num_base = MBase; + static constexpr int num_shft = SShift; + + static_assert(num_base >= 0, "MBase must be positive."); + static_assert(num_bits >= 0, "BBits must be positive."); + static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); + + // using 'int' type here to avoid unintentially casting to unsigned... unsure. + using bit_msk = cute::constant; + using yyy_msk = cute::constant; + using zzz_msk = cute::constant; + using msk_sft = cute::constant; + + static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); + + template ::value)> + CUTE_HOST_DEVICE constexpr static + auto + apply(Offset const& offset) + { + return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + operator()(Offset const& offset) const + { + return apply(offset); + } +}; + +// Translation for legacy SwizzleXor +// TODO: Deprecate +template +using SwizzleXor = Swizzle; + +// +// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> +// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle() +{ + constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits + constexpr uint32_t BY = popcount(Z); // Number of swizzle bits + static_assert(BZ == BY, "Number of bits in Y and Z don't match"); + constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y + constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z + constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32; + constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros + static_assert((Y | Z) == Swizzle::swizzle_code, "Something went wrong."); + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle, Swizzle) +{ + static_assert(S0 == S1, "Can only merge swizzles of the same shift."); + constexpr uint32_t Y = Swizzle::yyy_msk::value ^ Swizzle::yyy_msk::value; + constexpr uint32_t Z = Swizzle::zzz_msk::value ^ Swizzle::zzz_msk::value; + return make_swizzle(); + + //return ComposedFn, Swizzle>{}; +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + constexpr int NewM = M - log2_n; + if constexpr (NewM >= 0) { + return Swizzle{}; + } else { + return Swizzle{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Swizzle const& swizzle) +{ + if constexpr (sizeof_bits::value == sizeof_bits::value) { + return swizzle; + } else if constexpr (sizeof_bits::value > sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); + return upcast::value/sizeof_bits::value>(swizzle); + } else if constexpr (sizeof_bits::value < sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); + return downcast::value/sizeof_bits::value>(swizzle); + } +} + +// +// Utility for slicing and swizzle "offsets" +// + +// For swizzle functions, it is often needed to keep track of which bits are +// consumed and which bits are free. Furthermore, it is useful to know whether +// each of these bits is known statically or dynamically. + +// MixedBits is an integer class where some bits are known statically and some +// bits are known dynamically. These sets of bits are disjoint and it is known +// statically which bits are known dynamically. + +// MixedBits can only be manipulated through bitwise operations + +// Abstract value: StaticInt | (dynamic_int_ & StaticFlags) +template // 0: static, 1: dynamic +struct MixedBits +{ + // Representation invariants + static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); + static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); + // assert((dynamic_int_ & ~F) == 0); + + DynamicType dynamic_int_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_mixed_bits(constant const&, DynamicType const& d, constant const&) +{ + static_assert(is_integral::value); + if constexpr (is_static::value) { + static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed."); + return constant{} | (d & constant{}); // Just return a static int + } else if constexpr (f == 0) { + return constant{}; // Just return a static int + } else { + return MixedBits{d & f}; // MixedBits + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Explicit conversion for now -- consider casting on plus or minus +// + +template +CUTE_HOST_DEVICE constexpr +auto +to_integral(MixedBits const& m) +{ + //return S | (m.dynamic_int_ & F); + return S | m.dynamic_int_; +} + +// Any cute::is_integral +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +to_integral(I const& i) +{ + return i; +} + +// +// Operators +// + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(MixedBits const& m, constant const&) +{ + return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator==(constant const& s, MixedBits const& m) +{ + return m == s; +} + +// Bitwise AND +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 0X0 | 0X0 | 0X0 | + // 001 | 0X0 | 001 | 001 | 001 | + // 011 | 0X0 | 001 | 011 | 011 | + // 1X0 | 0X0 | 001 | 011 | 1X0 | + + return make_mixed_bits(constant{}, + //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), + ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m, constant const&) +{ + return make_mixed_bits(constant{}, + m.dynamic_int_, + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(constant const& s, MixedBits const& m) +{ + return m & s; +} + +// Bitwise OR +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 1X0 | + // 011 | 011 | 011 | 011 | 1X0 | + // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | + + return make_mixed_bits(constant{}, + ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m, constant const&) +{ + return make_mixed_bits(constant{}, + m.dynamic_int_, + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(constant const& s, MixedBits const& m) +{ + return m | s; +} + +// Bitwise XOR +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 011 | + // 011 | 011 | 011 | 001 | 001 | + // 1X0 | 1X0 | 011 | 001 | 0X0 | + + return make_mixed_bits(constant{}, + (S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m, constant const&) +{ + return make_mixed_bits(constant{}, + (S0 | m.dynamic_int_) ^ S1, + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(constant const& s, MixedBits const& m) +{ + return m ^ s; +} + +// +// upcast and downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(MixedBits const& m, constant const& s) +{ + static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two."); + return make_mixed_bits(safe_div(constant{}, s), + safe_div(m.dynamic_int_, s), + safe_div(constant{}, s)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); + return safe_div(m, constant{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +upcast(T const& m) +{ + return safe_div(m, constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); + return make_mixed_bits(constant{}, + m.dynamic_int_ * N, + constant{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +downcast(T const& m) +{ + return m * constant{}; +} + +// +// Convert a Pow2Layout+Coord to a MixedBits +// + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); }, + [](auto const&... a) { return (a ^ ...); }); + } else if constexpr (is_integral::value && is_integral::value && is_integral::value) { + static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride."); + return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride); + } else { + static_assert(is_integral::value && is_integral::value && is_integral::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral)."); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Layout const& layout, Coord const& coord) +{ + return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape())); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(MixedBits const& m) +{ + printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m)); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) +{ + return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m); +} + +template +CUTE_HOST_DEVICE void print(Swizzle const&) +{ + print("S<%d,%d,%d>", B, M, S); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) +{ + return os << "S<" << B << "," << M << "," << S << ">"; +} + +} // end namespace cute diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp new file mode 100644 index 0000000000..1376a47ddd --- /dev/null +++ b/include/cute/swizzle_layout.hpp @@ -0,0 +1,1010 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include + +/* This implements a ComposedLayout of the form + * InvolutionFn o OffsetPlus o Layout + * where the InvolutionFn need not be linear (hence the need for the Offset). + * + * This ComposedLayout provides similar coordinate-to-index mapping and layout manipulations, + * but is not considered a "normal" layout. + * For example, this layout provides size() functions, but does not provide stride() functions. + * + * Furthermore, for known InvolutionFns, this layout attempts to decay itself + * to a normal-layout with dynamic or static strides. + * This is possible by determining the subdomain of the Involution function + * that is identity and testing if the right Layout's codomain is contained + * within it. + */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout + : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(InvolutionFn const& fn = {}, + IntermediateOffset const& offset = {}, + Layout const& layout = {}) + : cute::tuple(fn, offset, layout) + {} + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + swizzle_fn() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset_fn() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_fn() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_fn().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return swizzle_fn()(to_integral(offset_fn()) + layout_fn()(coord)); // (F o L)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Map a 1D linear coordinate to a flat ND logical coordinate + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + operator[](Int const& linear_idx) const { + return get_flat_coord(linear_idx); + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post this->crd2idx(@a result) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return layout_fn().get_hier_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post this->crd2idx(@a result) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return layout_fn().get_flat_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post this->crd2idx(@a result) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return layout_fn().get_1d_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Swizzle const& sxor) +{ + return composition(sxor, Layout,Int<1>>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(ComposedLayout const& a, Layout const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), make_layout(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& a, ComposedLayout const& b) +{ + return composition(b.swizzle_fn(), b.offset_fn(), make_layout(a, b.layout_fn())); +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transfer_swizzle(Layout const& old_layout, + Layout const& new_layout) +{ + // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations + + // This is accomplished by identifying + // S o L :=: S? o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, old_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + + // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y + auto zzz_msk = typename Swizzle::zzz_msk{}; + auto yyy_msk = typename Swizzle::yyy_msk{}; + auto msk_sft = typename Swizzle::msk_sft{}; + auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; + auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); + auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); + + // Use this new swizzle identifier to construct the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(ComposedLayout,Offset,Layout> const& layout) +{ + return detail::transfer_swizzle(layout.layout_fn(), make_fragment_like(layout.layout_fn())); +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.swizzle_fn(), clayout.offset_fn(), layout(clayout.layout_fn())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_fn()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_fn()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_fn()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_fn()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_fn()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.swizzle_fn(), a.offset_fn(), get(a.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.swizzle_fn(), a.offset_fn(), take(a.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.swizzle_fn(), a.offset_fn(), flatten(a.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.swizzle_fn(), a.offset_fn(), append(a.layout_fn(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.swizzle_fn(), a.offset_fn(), group(a.layout_fn())); +} + +// +// Slice a ComposedLayout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(true_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... ranges for each B bits, + // and the layout won't slice along z-bits that are already set + + // y\z 0 1 + // 0 Z DC + // 1 -Z DC + + return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(false_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... ranges for each B bits, + // and the layout won't slice along y-bits that are already set + + // y\z 0 1 + // 0 Y+Z Y-Z + // 1 DC DC + + return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (all_underscore::value) { + // Skip the expensive/complicated attempt to decay to a normal layout and just reshape + return cute::make_tuple(composition(layout.swizzle_fn(), layout.offset_fn(), slice(coord, layout.layout_fn())), Int<0>{}); + } else { + + // Projections of the swizzle layout for composition + auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); + + auto swizzle_anti_zy = make_layout(shape(sw), + make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); + auto swizzle_only_zy = make_layout(shape(sw), + make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); + + // The portion of the layout that is not yet consumed + auto sliced_layout = slice(coord, layout.layout_fn()); + + // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay + + // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] + // (this also tests that shape/stride of layout compose with swizzle) + auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); + // Determine if any active bits collide under the swizzle + auto hit_ZandY = !(swizzle_active_bits & ~layout.swizzle_fn()(swizzle_active_bits)); + + // The portion of the layout that we are consuming now + auto diced_layout = dice(coord, layout.layout_fn()); + auto diced_coord = dice(coord, coord); + + auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); + auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); + + // New swizzle and offset + auto swizzle = layout.swizzle_fn(); + // offset_only_zy interacts with swizzle and gets accumulated with layout.offset_fn() + // being careful about the static/dynamic contributions from diced_layout and diced_coord + auto offset_only_zy = layout.offset_fn() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); + // offset_anti_zy always gets passed through, no interaction with swizzle + auto offset_anti_zy = diced_layout_anti_zy(diced_coord); + + // If Layout's codomain hits on Y AND Z, then it's not reducible + // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal + // If Layout's codomain hits on neither Y NOR Z, then it's static-normal + + // Test the sliced layout for hit_X & hit_Y for potential decay + if constexpr (is_constant::value) + { // Hits on Y AND Z, so it's not reducible + return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); + } else + { // Misses on Y or Z, so it's static-normal or dynamic-normal + + // Lowest bit of the Z and Y masks + auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; + auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; + auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); + auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); + + // Construct a (dynamic) layout that we can perform the composition with + auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), + make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); + + // Decay to a normal layout with offset + return cute::make_tuple(composition(swizzle_layout, sliced_layout), + swizzle(to_integral(offset_only_zy)) + offset_anti_zy); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle const& sxor, + Offset const& offset, + Layout const& layout) +{ + return ComposedLayout>{sxor, offset, layout}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle const& sxor, + Offset const& offset, + ComposedLayout const& layout) +{ + // Assume disjoint swizzles and offsets for commutivity + return composition(composition(sxor,layout.swizzle_fn()), offset ^ layout.offset_fn(), layout.layout_fn()); +} + +// Ignore identity case +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle<0,M,S> const&, + Int<0> const&, + Layout const& layout) +{ + return layout; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle const& sxor, + Layout const& layout) +{ + return composition(sxor, Int<0>{}, layout); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + LayoutOrTile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), composition(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + Swizzle const& b) +{ + // Get the Z bits and the Y bits + auto active_Y = a(typename Swizzle::yyy_msk{}); + auto active_Z = a(typename Swizzle::zzz_msk{}); + + // Works in simple cases... but could be greatly generalized + + return composition(make_swizzle(), a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); + + return composition(composition(a, b.swizzle_fn()), b.layout_fn()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + ComposedLayout const& b) +{ + auto asb = composition(a.layout_fn(), b); + + return composition(composition(a.swizzle_fn(),asb.swizzle_fn()), asb.offset_fn(), asb.layout_fn()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) +{ + // Assume there is no swizzle component in the complement + return complement(layout.layout_fn(), cosize_hi); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); + return composition(right_inverse(layout.layout_fn()), layout.swizzle_fn()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); + return composition(left_inverse(layout.layout_fn()), layout.swizzle_fn()); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset,SLayout> const& a, + Layout const& b) +{ + // This assumes that Offset is in the YZ domain of the Swizzle... + return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + ComposedLayout,Offset,SLayout> const& b) +{ + return max_common_vector(b, a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, + ComposedLayout,Offset1,SLayout1> const& b) +{ + auto result = coalesce(composition(a, right_inverse(b))); + + if constexpr (is_constant<1, decltype(stride<0>(result.layout_fn()))>::value) { + return shape<0>(result); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.swizzle_fn(), a.offset_fn(), zip(a.layout_fn())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), logical_divide(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), tile_unzip(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), tiled_divide(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), zipped_divide(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), logical_product(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), tiled_product(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), blocked_product(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.swizzle_fn(), a.offset_fn(), raked_product(a.layout_fn(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.swizzle_fn(), layout.offset_fn(), tile_to_shape(layout.layout_fn(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.swizzle_fn(), layout.offset_fn(), filter(layout.layout_fn(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn(), trg_profile)); +} + +/////////////////////////////////////////////////////////////////////////////// +// ComposedLayout as second argument is often more difficult... + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + ComposedLayout,Offset,LayoutT> const& tile) +{ + CUTE_STATIC_ASSERT_V(tile.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); + // The new layout -- if swizzle wasn't an issue, this is the result + // our goal is to determine a new swizzle for these strides + auto new_layout = logical_product(block, tile.layout_fn()); + + // This is accomplished by identifying + // S o L :=: S? o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, tile.layout_fn()); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + // Get the Z bit and the Y bits + auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; + auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(Int<0>{}, tile.layout_fn()[active_Z]); + auto new_active_Y = new_layout(Int<0>{}, tile.layout_fn()[active_Y]); + + // Use this new swizzle identifier to construxt the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& block, + ComposedLayout const& tile) +{ + /// Avoid swizzle slice + auto result = logical_product(block, tile); + return composition(result.swizzle_fn(), result.offset_fn(), result.layout_fn()(_, repeat>(_))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + ComposedLayout const& layout) +{ + constexpr int R = cute::max(rank_v, rank_v); + auto padded_block = append(block, Layout<_1,_0>{}); + auto padded_layout = append(layout, Layout<_1,_0>{}); + + auto result = logical_product(padded_block, padded_layout); + + return composition(result.swizzle_fn(), + result.offset_fn(), + coalesce(zip(get<0>(result.layout_fn()), get<1>(result.layout_fn())), repeat(Int<1>{}))); +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.swizzle_fn()), upcast(layout.offset_fn()), upcast(layout.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.swizzle_fn()), downcast(layout.offset_fn()), downcast(layout.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(ComposedLayout const& layout) +{ + if constexpr (sizeof(NewType) == sizeof(OldType)) { + return layout; + } else if constexpr (sizeof(NewType) > sizeof(OldType)) { + static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); + return upcast(layout); + } else if constexpr (sizeof(NewType) < sizeof(OldType)) { + static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); + return downcast(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); +} + +} // end namespace cute diff --git a/include/cute/swizzle_ptr.hpp b/include/cute/swizzle_ptr.hpp new file mode 100644 index 0000000000..ed77acba75 --- /dev/null +++ b/include/cute/swizzle_ptr.hpp @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include +#include + +#include +#include +#include + +/* This implements a swizzle pointer of the form + * InvolutionFn o PtrAdd + * where the InvolutionFn need not be linear. + * + * This differs subtly from swizzle_layout because the smem pointer is used + * as the offset. That means that swizzle_layout will implement position-independent + * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. + * Arch chose to design hardware with position-dependent swizzles. + * + * For clarity: + * NormalLayout : DeRef <- PtrAdd <- [Layout] + * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] + * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout + * + * Furthermore, for known swizzles, this pointer attempts to decay itself + * to a normal-pointer with a new layout containing dynamic or static strides. + * This is possible by determining the subdomain of the InvolutionFn + * that is identity and testing if the Layout's codomain is contained + * within it. + */ + +namespace cute +{ + +template +struct smem_ptr_swizzle +{ + static_assert(std::is_empty::value, "Swizzle can't have state."); + + CUTE_HOST_DEVICE constexpr + T* get() const + { + return ptr_; + } + + CUTE_HOST_DEVICE constexpr static + Swizzle get_swizzle() + { + return {}; + } + + CUTE_HOST_DEVICE constexpr static + T* apply_swizzle(T* ptr) + { + return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); + } + + CUTE_HOST_DEVICE constexpr + T& operator*() const + { + return *apply_swizzle(get()); + } + + template + CUTE_HOST_DEVICE constexpr + T& operator[](Int const& i) const + { + return *apply_swizzle(get() + i); + } + + template + CUTE_HOST_DEVICE constexpr + smem_ptr_swizzle operator+(Int const& i) const + { + return {ptr_ + i}; + } + + T* ptr_; +}; + +template +struct is_smem> : true_type {}; + +// Make a swizzle pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(T* ptr, Swizzle const& swizzle) +{ + return smem_ptr_swizzle{ptr}; +} + +// A model of a nullptr smem_ptr with B == sizeof_bits::value +// That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr +template +struct smem_ptr_flag_bits : Int<0> {}; + +using smem_ptr_flag = smem_ptr_flag_bits<1>; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(smem_ptr const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); + return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()), + layout.layout_fn()); +} + +// Specialization for immediate decay +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(smem_ptr_swizzle>& p, Layout const& layout) +{ + return make_tensor(make_smem_ptr(p.ptr_), layout); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(smem_ptr_swizzle> const& p, Layout const& layout) +{ + return make_tensor(make_smem_ptr(p.ptr_), layout); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, upcast(layout.layout_fn())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, downcast(layout.layout_fn())); +} + +// +// Recast +// Swizzle operates on the pointer address, so it doesn't care about the type +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast(smem_ptr_swizzle const& ptr) +{ + return smem_ptr_swizzle{recast(ptr.ptr_)}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(smem_ptr_swizzle const& ptr) +{ + return smem_ptr_swizzle{recast(ptr.ptr_)}; +} + +// +// Conversion with swizzle_layout +// + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) +{ + return composition(recast,uint_bit_t>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn()); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor>, Layout> const& tensor) +{ + { + uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); + uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); + assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle + } + auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); + return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) +{ + { + uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); + uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); + assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle + } + auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); + return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor>, Layout>&& tensor) +{ + return as_position_independent_swizzle_tensor(tensor); +} + +// +// Print +// + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout) +{ + auto new_swizzle = recast,uint_bit_t>(layout.swizzle_fn()); + print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn())); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) +{ + printf("smem_ptr_%db(unset)", B); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) +{ + printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) +{ + return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; +} + +} // end namespace cute diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp new file mode 100644 index 0000000000..e88c22bcb7 --- /dev/null +++ b/include/cute/tensor.hpp @@ -0,0 +1,900 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cute +{ + +// +// Engine -- owning or non-owning data store +// + +// concept Engine { +// using value_type = ; +// iterator begin(); +// }; + +template +using ArrayEngine = typename std::conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + +template +struct ViewEngine +{ + using value_type = typename cute::remove_cvref())>::type; + + using iterator = Iterator; + iterator storage_; + + CUTE_HOST_DEVICE constexpr + iterator const& + begin() const { + return storage_; + } + + CUTE_HOST_DEVICE constexpr + iterator& + begin() { + return storage_; + } +}; + +template +struct is_rmem> : is_rmem {}; +template +struct is_smem> : is_smem {}; +template +struct is_gmem> : is_gmem {}; +template +struct ConstViewEngine +{ + using value_type = typename cute::remove_cvref())>::type; + + using iterator = Iterator; + iterator storage_; + + CUTE_HOST_DEVICE constexpr + iterator const& + begin() const { + return storage_; + } +}; + +template +struct is_rmem> : is_rmem {}; +template +struct is_smem> : is_smem {}; +template +struct is_gmem> : is_gmem {}; +// +// Tensor +// + +template +struct Tensor +{ + using value_type = typename Engine::value_type; + //using pointer = typename engine_traits::pointer; + //using const_pointer = typename engine_traits::const_pointer; + //using reference = typename engine_traits::reference; + //using const_reference = typename engine_traits::const_reference; + + using engine_type = Engine; + using layout_type = Layout; + + CUTE_HOST_DEVICE constexpr + Tensor() {} + + template + CUTE_HOST_DEVICE constexpr + Tensor(Ptr const& ptr, Layout const& layout) + : rep_(layout, ptr) { + } + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + tensor() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return get<0>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() const { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() const { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout().shape(); + } + + CUTE_HOST_DEVICE constexpr + auto + size() const { + return cute::size(shape()); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return layout().stride(); + } + + // + // Indexing op() and op[] + // + + // Index into this tensor like an array by computing the offset via layout() + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) const { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { + return operator()(make_coord(c0,c1,cs...)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) { + return make_tensor(data(), layout().compose(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return make_tensor(data(), layout().compose(layouts...)); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) { + return make_tensor(data(), layout().tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return make_tensor(data(), layout().tile(layouts...)); + } + + // + // Utility + // + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(Int const& linear_idx) const { + return layout().get_1d_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(Int const& linear_idx) const { + return layout().get_hier_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(Int const& linear_idx) const { + return layout().get_flat_coord(linear_idx); + } + + cute::tuple rep_; +}; + + +template +struct is_tensor : false_type {}; +template +struct is_tensor> : true_type {}; + +template +struct is_rmem> : is_rmem {}; +template +struct is_smem> : is_smem {}; +template +struct is_gmem> : is_gmem {}; +// +// Make an owning Tensor that will allocate a static array +// + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Layout const& layout) +{ + static_assert(is_static::value, "Dynamic owning tensors not supported"); + using Engine = ArrayEngine>; + return Tensor(); +} + +// e.g. make_tensor(12) +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_tensor(LayoutArg const& arg, LayoutArgs const&... args) +{ + return make_tensor(make_layout(arg, args...)); +} + +// +// Make a non-owning Tensor that will use a pointer (view) +// + +template ::value && + is_layout::value)> +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& iter, Layout const& layout) +{ + using Engine = ViewEngine; + return Tensor(iter, layout); +} + +// e.g. make_tensor(vec.data(), 12) +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... args) +{ + return make_tensor(iter, make_layout(arg, args...)); +} + +// +// make_tensor_like -- make a register tensor the same type and shape as another +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + using value_type = typename Tensor::value_type; + return make_tensor(tensor.shape()); +} + +// +// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + using value_type = typename Tensor::value_type; + return make_tensor(make_layout_like(tensor.layout())); +} + +// +// make_identity_tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_tensor(Shape const& shape) +{ + return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))), + make_identity_layout(shape)); +} + +// +// Utilities +// + +// Return the subtensor of a mode +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +tensor(Tensor&& tensor) +{ + return std::forward(tensor); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +tensor(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), get(tensor.layout())); +} + +// Return the subtensor of a range of modes +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +take(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), take(tensor.layout())); +} + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Tensor const& tensor) +{ + return layout(tensor.layout()); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Tensor const& tensor) +{ + return shape(tensor.layout()); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Tensor const& tensor) +{ + return stride(tensor.layout()); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(Tensor const& tensor) +{ + return size(tensor.layout()); +} + +// Return the rank of a mode +template +CUTE_HOST_DEVICE constexpr +auto +rank(Tensor const& tensor) +{ + return rank(tensor.layout()); +} + +// Return the depth of a mode +template +CUTE_HOST_DEVICE constexpr +auto +depth(Tensor const& tensor) +{ + return depth(tensor.layout()); +} + +// +// Operations to manipulate Tensors like a Layout +// + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), flatten(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout(), profile)); +} + +// Group the modes [B,E) into a single mode +// e.g. group<2,4>(make_tensor(Layout>{})) +// => make_tensor(Layout,_5,_6>>{}) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), + group(tensor.layout())); +} + +// +// Recast +// + +// NOTE: This is very dangerous to do +// -- doesn't check dynamic integer divisibility +// -- doesn't check alignment + +// A tagged version for dispatching +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor, type_list) +{ + using OldType = typename remove_cvref_t::value_type; + auto old_layout = tensor.layout(); + auto new_layout = recast(old_layout); + + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast(std::forward(tensor).data() + offset), new_layout); + } else { + return make_tensor(recast(std::forward(tensor).data() ), new_layout); + } + + CUTE_GCC_UNREACHABLE; +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor) +{ + return recast(std::forward(tensor), type_list{}); +} + +// +// max_common_vector +// + +/* Return Int such that N is the maximum number of continguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the number of elements that could reasonably be vectorized into a single load/store. + * + * @returns Int with N >= 0 + * + * A return value of Int<0> indicates that no such conclusion can be made and no + * vectorization should be attempted. + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename Tensor::value_type; + using DstType = typename Tensor::value_type; + + using SrcRef = decltype(*(a.data())); + using DstRef = decltype(*(b.data())); + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + sizeof(SrcType) == sizeof(DstType) && + // The types should be trivially copyable so that vectorization is valid + std::is_trivially_copyable::value && + std::is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + std::is_reference::value && + std::is_reference::value) + { + return max_common_vector(a.layout(), b.layout()); + } else { + return Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Key algebraic operations +// + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Tensor && tensor, + Tile const& tile) +{ + return make_tensor(std::forward(tensor).data(), + logical_divide(tensor.layout(), tile)); +} + +// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b)) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Tensor && tensor, + Tile const& tile) // Layout or Tile +{ + return make_tensor(std::forward(tensor).data(), + zipped_divide(tensor.layout(), tile)); +} + +// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Tensor && tensor, + Tile const& tile) // Layout or Tile +{ + return make_tensor(std::forward(tensor).data(), + tiled_divide(tensor.layout(), tile)); +} + +// logical_product on a Tensor doesn't make sense since it often increases cosize + +// +// Logicial Divide utilities: local_partition and local_tile +// + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_partition(Tensor && tensor, + Tile const& tile, + Coord const& coord) +{ + constexpr int R1 = decltype(rank(tensor))::value; + + // Split the modes of tensor according to the modes of tile + // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) + + // The_coord is the coord into the first mode, flatten the rest + return zipped_divide(std::forward(tensor), tile)(coord, repeat(_)); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_partition(Tensor && tensor, + Tile const& tile, + Coord const& coord, + Projection const& proj) +{ + return local_partition(std::forward(tensor), + dice(proj, tile), + dice(proj, coord)); +} + +// Special case with Layout and Integral that extracts the coord first +// e.g. local_partition(tensor, ThrLayout, threadIdx.x) +template >::value && + is_integral::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, + Index const& index) +{ + return local_partition(std::forward(tensor), + product_each(shape(tile)), + tile.get_flat_coord(index)); +} + +// Special case with Layout and Integral that extracts the coord first +// e.g. local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{}) +template >::value && + is_integral::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, + Index const& index, + Projection const& proj) +{ + return local_partition(std::forward(tensor), + dice(proj, product_each(shape(tile))), + dice(proj, tile).get_flat_coord(index)); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_tile(Tensor && tensor, + Tile const& tile, + Coord const& coord) +{ + constexpr int R0 = decltype(rank(tile))::value; + constexpr int R1 = decltype(rank(tensor))::value; + + // Split the modes of tensor according to the modes of tile + // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) + + // The padded_coord is the coord into the second mode, flatten the rest + return zipped_divide(std::forward(tensor), tile)(repeat(_), append(coord,_)); +} + +template >::value)> +CUTE_HOST_DEVICE +auto +local_tile(Tensor && tensor, + Tile const& tile, + Coord const& coord, + Proj const& proj) +{ + return local_tile(std::forward(tensor), + dice(proj, tile), + dice(proj, coord)); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) +{ + auto format = get_format(tensor(0)); + using type = typename decltype(format)::type; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + printf(format.format, format.digits, type(tensor(m))); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + printf(format.format, format.digits, type(tensor(m,n))); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p)); + } + } +} + +template +CUTE_HOST_DEVICE void print(Tensor const& tensor) +{ + print(tensor.layout()); print("\n"); + print_tensor(tensor); +} + +template +CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} + +} // end namespace cute + +// +// Extended Engines +// + +#include + +// +// Tensor Algorithms +// + +#include +#include +#include +#include +#include +#include diff --git a/include/cute/tensor_predicate.hpp b/include/cute/tensor_predicate.hpp new file mode 100644 index 0000000000..730f219462 --- /dev/null +++ b/include/cute/tensor_predicate.hpp @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +template +struct ConstantTensor +{ + template + CUTE_HOST_DEVICE constexpr + T const& + operator()(Coords const&...) const { + return val_; + } + + T val_; +}; + +struct TrivialPredTensor +{ + template + CUTE_HOST_DEVICE constexpr + true_type + operator()(Coords const&...) const { + return {}; + } +}; + +} // end namespace cute diff --git a/include/cute/tile.hpp b/include/cute/tile.hpp new file mode 100644 index 0000000000..b2fa2e8b7b --- /dev/null +++ b/include/cute/tile.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores +// + +template +using Tile = tuple; + +template +using is_tile = is_tuple; + +template +CUTE_HOST_DEVICE constexpr +auto +make_tile(Layouts const&... layouts) +{ + return Tile(layouts...); +} + +} // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp new file mode 100644 index 0000000000..d79b4ee8c4 --- /dev/null +++ b/include/cute/underscore.hpp @@ -0,0 +1,148 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// For slicing +struct Underscore : Int<0> {}; + +CUTE_INLINE_CONSTANT Underscore _; + +// Treat Underscore as an integral like integral_constant +template <> +struct is_integral : true_type {}; + +template +struct is_underscore : false_type {}; +template <> +struct is_underscore : true_type {}; + +// Tuple trait for detecting static member element +template +struct has_elem : false_type {}; +template +struct has_elem : true_type {}; +template +struct has_elem::value> > + : has_elem > {}; +template +struct has_elem> + : disjunction, Elem>...> {}; + +// Tuple trait for detecting static member element +template +struct all_elem : false_type {}; +template +struct all_elem : true_type {}; +template +struct all_elem::value> > + : all_elem > {}; +template +struct all_elem> + : conjunction, Elem>...> {}; + +// Tuple trait for detecting Underscore member +template +using has_underscore = has_elem; + +template +using all_underscore = all_elem; + +template +using has_int1 = has_elem>; + +template +using has_int0 = has_elem>; + +// +// Slice keeps only the elements of Tuple B that are paired with an Underscore +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple{b}; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Dice keeps only the elements of Tuple B that are paired with an Int +// + +template +CUTE_HOST_DEVICE constexpr +auto +dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return cute::tuple{b}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +CUTE_HOST_DEVICE void print(Underscore const&) { + printf("_"); +} + +CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { + return os << "_"; +} + +} // end namespace cute diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp new file mode 100644 index 0000000000..9a62143c95 --- /dev/null +++ b/include/cute/util/debug.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/** + * \file + * \brief Debugging and logging functionality + */ + +#include + +#include + +namespace cute +{ + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUTE_LOG) +# if !defined(__CUDA_ARCH__) +# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) +# else +# define CUTE_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, blockIdx.y, blockIdx.z, \ + threadIdx.x, threadIdx.y, threadIdx.z, \ + __VA_ARGS__); +# endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUTE_LOG_DEBUG) +# ifdef DEBUG +# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) +# else +# define CUTE_LOG_DEBUG(format, ...) +# endif +#endif + +/** + * \brief Perror macro with exit + */ +#if !defined(CUTE_ERROR_EXIT) +# define CUTE_ERROR_EXIT(e) \ + do { \ + cudaError_t code = (e); \ + if (code != cudaSuccess) { \ + fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ + __FILE__, __LINE__, #e, \ + cudaGetErrorName(code), cudaGetErrorString(code)); \ + fflush(stderr); \ + exit(0); \ + } \ + } while (0) +#endif + +#if !defined(CUTE_CHECK_LAST) +# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) +#endif + +#if !defined(CUTE_CHECK_ERROR) +# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) +#endif + +// A dummy function that uses compilation failure to print a type +template +CUTE_HOST_DEVICE +void +print_type(T&&) { + static_assert(sizeof(T) < 0, "Printing type T."); +} + +// +// Device-specific helpers +// +// e.g. +// if (thread0()) print(...); +// if (block0()) print(...); +// if (thread(42)) print(...); + +CUTE_HOST_DEVICE +bool +thread(int tid, int bid) +{ +#if defined(__CUDA_ARCH__) + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) + && ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid); +#else + return true; +#endif +} + +CUTE_HOST_DEVICE +bool +thread(int tid) +{ + return thread(tid, 0); +} + +CUTE_HOST_DEVICE +bool +thread0() +{ + return thread(0,0); +} + +CUTE_HOST_DEVICE +bool +block0() +{ +#if defined(__CUDA_ARCH__) + return !(blockIdx.x | blockIdx.y | blockIdx.z); +#else + return true; +#endif +} + +} // end namespace cute diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp new file mode 100644 index 0000000000..ec774b00ff --- /dev/null +++ b/include/cute/util/print.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// +// CUDA compatible print and printf +// + +namespace cute +{ + +CUTE_HOST_DEVICE +int +num_digits(int x) +{ + return (x < 10 ? 1 : + (x < 100 ? 2 : + (x < 1000 ? 3 : + (x < 10000 ? 4 : + (x < 100000 ? 5 : + (x < 1000000 ? 6 : + (x < 10000000 ? 7 : + (x < 100000000 ? 8 : + (x < 1000000000 ? 9 : + 10))))))))); +} + +template +struct format_and_size { + using type = T; + char const* format; + int digits; +}; + +CUTE_HOST_DEVICE +format_and_size +get_format(bool) { + return {"%*d", 3}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(int32_t) { + return {"%*d", 5}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(uint32_t) { + return {"%*d", 5}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(int64_t) { + return {"%*d", 5}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(uint64_t) { + return {"%*d", 5}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(half_t) { + return {"%*.2f", 8}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(float) { + return {"%*.2e", 10}; +} + +CUTE_HOST_DEVICE +format_and_size +get_format(double) { + return {"%*.3e", 11}; +} + +// +// print dispatcher +// + +CUTE_HOST_DEVICE +void +print(char const& c) { + printf("%c", c); +} + +template ::value)> +CUTE_HOST_DEVICE +void +print(T const& a) { + printf("%d", int(a)); +} + +template +CUTE_HOST_DEVICE +void +print(char const* format, T const&... t) { + printf(format, t...); +} + +} // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp new file mode 100644 index 0000000000..4d37eb9e48 --- /dev/null +++ b/include/cute/util/type_traits.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr +#define __CUTE_REQUIRES_V(...) typename std::enable_if::type* = nullptr + +namespace cute +{ + +using std::conjunction; +using std::conjunction_v; + +using std::disjunction; +using std::disjunction_v; + +using std::negation; +using std::negation_v; + +using std::void_t; + +// C++20 +// using std::remove_cvref; +template +struct remove_cvref { + using type = std::remove_cv_t>; +}; + +// C++20 +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + +// +// is_valid +// + +namespace detail { + +template ()(std::declval()...))> +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(int) { return std::true_type{}; } + +template +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(...) { return std::false_type{}; } + +template +struct is_valid_fn { + template + CUTE_HOST_DEVICE constexpr auto + operator()(Args&&...) const { return is_valid_impl(int{}); } +}; + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&) { + return detail::is_valid_fn{}; +} + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&, Args&&...) { + return detail::is_valid_impl(int{}); +} + +} // end namespace cute diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h new file mode 100644 index 0000000000..34f0b4ee72 --- /dev/null +++ b/include/cutlass/arch/barrier.h @@ -0,0 +1,404 @@ +/*************************************************************************************************** + * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Barrier Operations on SM90+ +*/ + +#pragma once + +#include +#include + +namespace cutlass { +/// @brief +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) +#define CUDA_BARRIER_ENABLED 1 +#else +#define CUDA_BARRIER_ENABLED 0 +#endif + +class NamedBarrier { + + // Data Members: + + // Range = [1 , NUM_THREADS_PER_CTA] + // Range % warp-size (i.e 32) == 0 + uint32_t const num_threads_; + + // Range : [0, 15] + uint32_t const id_; + + public: + + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, uint32_t id = 0) + : num_threads_(num_threads), id_(id) {} + + CUTLASS_DEVICE + void arrive_and_wait() const { + NamedBarrier::arrive_and_wait(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive() const { + NamedBarrier::arrive(num_threads_, id_); + } + + CUTLASS_DEVICE + void sync() const { + NamedBarrier::arrive_and_wait(); + } + + // Static variants + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void sync(uint32_t num_threads, uint32_t barrier_id) { + NamedBarrier::arrive_and_wait(num_threads, barrier_id); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide AW behaviour. +// This is an extension to the Ampere AW barriers +// Note : Ampere AW Barriers have a larger max-arrive count (2^30) than Hopper AW Barriers (2^20). +struct ClusterBarrier { + + using ValueType = uint64_t; + +protected: + // Can never be initializated - can only be aliased to smem + ValueType barrier_; + +public: + + CUTLASS_DEVICE + ClusterBarrier() = delete; + + CUTLASS_DEVICE + void init(uint32_t arrive_count) const { + ClusterBarrier::init(&this->barrier_, arrive_count); + } + + CUTLASS_DEVICE + uint32_t test_wait(uint32_t phase, uint32_t pred=true) const { + return ClusterBarrier::test_wait(&this->barrier_, phase, pred); + } + + CUTLASS_DEVICE + void wait(uint32_t phase) const { + ClusterBarrier::wait(&this->barrier_, phase); + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + void arrive() const { + ClusterBarrier::arrive(&this->barrier_); + } + + // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) + CUTLASS_DEVICE + void arrive(uint32_t cta_id, uint32_t pred = true ) const { + ClusterBarrier::arrive(&this->barrier_, cta_id, pred); + } + + // + // Static Versions + // + CUTLASS_DEVICE + static void init(ValueType const* smem_ptr, uint32_t arrive_count) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.init.shared.b64 [%1], %0; \n" + "}" + : + : "r"(arrive_count), "r"(smem_addr)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + // Static version of wait - in case we don't want to burn a register + CUTLASS_DEVICE + static void wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + // Arbitrarily large timer value after which try-wait expires and re-tries. + uint32_t ticks = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_addr), "r"(phase), "r"(ticks)); + +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + ".reg .pred P2; \n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase), "r"(pred)); + + return waitComplete; +#else + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + // Static Predicated version of the above - in case we know the address. + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint64_t state = 0; + asm volatile( + "{\n\t" + "mbarrier.arrive.shared.b64 %1, [%0];\n\t" + "}" + : + : "r"(smem_addr), "l"(state)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void invalidate(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.ival.shared.b64 [%0]; \n\t" + "}" + : + : "r"(smem_addr)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 also introduces a new type of cluster-barrier which supports sync. +// not just based on Arrive Count, but also transaction count (in bytes) +struct ClusterTransactionBarrier : public ClusterBarrier { + + CUTLASS_DEVICE + ClusterTransactionBarrier() = delete; + + // Performs an arrive operation + bytes reset + CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes); + } + + // Performs an arrive operation + bytes reset + CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { + ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true); + } + + CUTLASS_DEVICE + void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { + uint32_t cta_rank = cute::block_rank_in_cluster(); + ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred); + } + + CUTLASS_DEVICE + void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred); + } + + // + // Static Versions + // + + // Performs an arrive operation + bytes reset + CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an arrive operation + bytes reset for a remote cta_id in a Cluster + CUTLASS_DEVICE + static void arrive_and_reset_bytes( + ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an bytes reset without doing an arrive operation + CUTLASS_DEVICE + static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.expect_tx.shared.b64 [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } + + // Increments transaction bytes in the barrier + CUTLASS_DEVICE + static void commit( + ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); +#else + asm volatile ("brkpt;\n" ::); +#endif + } +}; + +// Helps with visibility of barrier init operations across warps / cta / cluster +// Available as a separate function so as to batch inits across barriers and fence once +// Note : It must be composed with an appropriate sync instruction with the right scope +// to ensure visibility eg. __syncthreads() or a cluster_arrive() + cluster_wait() +CUTLASS_DEVICE +void fence_barrier_init() { +#if CUDA_BARRIER_ENABLED + asm volatile( + "{\n\t" + "fence.mbarrier_init.release.cluster; \n" + "}" + ::); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +// Issue a shared memory fence for async operations +CUTLASS_DEVICE +void fence_view_async_shared() { +#if CUDA_BARRIER_ENABLED + asm volatile ( + "{\n\t" + "fence.proxy.async.shared::cta; \n" + "}" + ::); +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace arch +} // end namespace cutlass diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index 5f45eb5858..ba59364f5e 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -36,6 +36,7 @@ #include "cutlass/array.h" #include "cutlass/layout/matrix.h" +#include "cute/arch/util.hpp" namespace cutlass { namespace arch { @@ -65,74 +66,13 @@ inline __device__ void ldsm(Array & D, void const* ptr); #define CUDA_LDMATRIX_SUPPORTED 1 #endif -///////////////////////////////////////////////////////////////////////////////////////////////// -/* -#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10) - #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1 -#endif -#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) - #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1)) -#endif - -#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED) - #define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED -#endif -*/ - -#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) - extern "C" { - // - // This NVVM intrinsic is subject to change in future versions of CUDA. - // Clients should not call it directly. Rather, they should use the - // cutlass::arch::ldsm<>() template. - // - __device__ uint32_t __nvvm_get_smem_pointer(void *); - } -#endif - ///////////////////////////////////////////////////////////////////////////////////////////////// /// CUTLASS helper to get SMEM pointer inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { - -// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to -// the previous internal intrinsics if they are available. -#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) - // - // This NVVM intrinsic converts an address in shared memory to a plain - // unsigned integer. This is necessary to pass to shared memory instructions - // in inline PTX. - // - // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. - // - //__device__ size_t __cvta_generic_to_shared(void* ptr); - - /// CUTLASS helper to get SMEM pointer - return static_cast(__cvta_generic_to_shared(ptr)); - -#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) - - return __nvvm_get_smem_pointer(ptr); - -#elif defined(__CUDA_ARCH__) - - uint32_t smem_ptr; - - asm( - "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" - : "=r"(smem_ptr) : "l"(ptr)); - - return smem_ptr; - -#else - - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - return 0; - -#endif + return cute::cast_smem_ptr_to_uint(ptr); } - + /// CUTLASS helper to get SMEM pointer inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { return cutlass_get_smem_pointer(const_cast(ptr)); diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 587ff8864f..7d4d693a09 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -224,5 +224,4 @@ struct SparseMma; #include "cutlass/arch/mma_sm80.h" #include "cutlass/arch/mma_sparse_sm80.h" #include "cutlass/arch/mma_sm90.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index cb7debc8bc..8682ae1ba8 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -2166,7 +2166,7 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else - + CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index 85e808a59d..1d0745b408 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -47,10 +47,21 @@ //////////////////////////////////////////////////////////////////////////////// #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) -#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -#define CUTLASS_ARCH_MMA_SM90_ENABLED + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED + #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED + #endif + #endif #endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUTLASS_ARCH_MMA_SM90_SUPPORTED + #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + #define CUTLASS_ARCH_MMA_SM90_ENABLED + #endif + #endif #endif //////////////////////////////////////////////////////////////////////////////// @@ -97,7 +108,7 @@ struct Mma< void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { -#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) double const *A = reinterpret_cast(&a); double const *B = reinterpret_cast(&b); @@ -105,10 +116,73 @@ struct Mma< double const *C = reinterpret_cast(&c); double *D = reinterpret_cast(&d); - asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) - : "d"(A[0]), "d"(A[1]), - "d"(B[0]), + : "d"(A[0]), "d"(A[1]), + "d"(B[0]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x8 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,8>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,8>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3]) + : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), + "d"(B[0]), "d"(B[1]), "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); #else @@ -118,7 +192,65 @@ struct Mma< CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x16 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]) + "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -129,3 +261,4 @@ struct Mma< } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h new file mode 100644 index 0000000000..2b74a22e6c --- /dev/null +++ b/include/cutlass/arch/reg_reconfig.h @@ -0,0 +1,68 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief PTX for CTA Reconfiguration +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + #if (defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif +#else + #define CUDA_CTA_RECONFIG_ACTIVATED 0 +#endif + +namespace cutlass { +namespace arch { + +template +CUTLASS_DEVICE +void warpgroup_reg_alloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} + +template +CUTLASS_DEVICE +void warpgroup_reg_dealloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index d3822da97c..ac30422408 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -370,8 +370,6 @@ class Array { CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - // TODO }; /// Bidirectional constant iterator over elements @@ -390,8 +388,6 @@ class Array { CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - // TODO }; private: diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp new file mode 100644 index 0000000000..4843540752 --- /dev/null +++ b/include/cutlass/cluster_launch.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief PTX for TMA Tensor Memory Access operators on memory added for SM90 +*/ + +#pragma once + +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +namespace cutlass { + +#ifndef NDEBUG +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + fprintf(stderr, \ + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(cudaError_t_status)); \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#else +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#endif + +struct ClusterLauncher { + constexpr static int MaxClusterSize = 32; + + // Check for hardware compatibility + static inline __host__ + Status check_cluster_dims(dim3 const& grid, dim3 const& cluster) { + if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && + (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); + return Status::kInvalid; + } + } + + static inline __host__ + Status +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + init(void const* kernel_function) +#else + init(void const* /* kernel_function */) +#endif + { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + // This attribute was added in CUDA 11.8. + cudaError_t status = + cudaFuncSetAttribute( + kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); + Return_Status(status); +#else + return Status::kInvalid; +#endif + } + + // This is the method we expect to use going forward + static inline __host__ + Status launch( + dim3 const& grid_dims, + dim3 const& cluster_dims, + dim3 const& block_dims, + size_t const& smem_size, + cudaStream_t& cuda_stream, + void const* kernel, + void** kernel_params) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); + return Status::kInvalid; + } + + cudaLaunchConfig_t launch_config; + launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; + launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; + launch_config.dynamicSmemBytes = smem_size; + launch_config.stream = cuda_stream; + + cudaLaunchAttribute launch_attribute[1]; + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + launch_attribute[0].val.clusterDim.x = cluster_dims.x; + launch_attribute[0].val.clusterDim.y = cluster_dims.y; + launch_attribute[0].val.clusterDim.z = cluster_dims.z; + + launch_config.attrs = launch_attribute; + launch_config.numAttrs = 1; + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " + "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " + "And ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); + return Status::kInvalid; +#endif + } +}; + +} // namespace cutlass diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index 2d2d249466..11ac967c65 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -332,7 +332,7 @@ struct ImplicitGemmConvolution { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h index d65a34e1e7..b740c9058f 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -339,7 +339,7 @@ struct ImplicitGemmConvolutionFusion { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 75d0338b86..7304cbdecb 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -335,7 +335,7 @@ struct ImplicitGemmConvolutionStridedDgrad { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index 8c6013c5d5..3fa7daca1b 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -341,7 +341,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 1884443470..12bc3a3717 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -72,20 +72,20 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) #include -#if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -#else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -#endif - -#else - -#if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -#else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -#endif + #if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } + #else + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif + #else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) + #endif + #endif #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -181,10 +181,11 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// -static const int NUM_THREADS_PER_WARP = 32; -static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; -static const int NUM_THREADS_PER_QUAD = 4; -static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2; +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -197,6 +198,28 @@ CUTLASS_HOST_DEVICE bool thread0() { #endif } +/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_idx() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); + #else + return 0; + #endif +} + +/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_group_idx() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); + #else + return 0; + #endif +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index d2903ac352..68042e3fb0 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -34,7 +34,24 @@ #pragma once -#include "cutlass/cutlass.h" +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTLASS_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +# define CUTLASS_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTLASS_GRID_CONSTANT) +# if defined(CUTLASS_GRID_CONSTANT_ENABLED) +# define CUTLASS_GRID_CONSTANT __grid_constant__ +# else +# define CUTLASS_GRID_CONSTANT +# endif +#endif + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -75,5 +92,22 @@ void Kernel2(typename Operator::Params params) { //////////////////////////////////////////////////////////////////////////////// -} /// namespace cutlass +// +// 3.0 specific launch +// +//////////////////////////////////////////////////////////////////////////////// + +/// Generic CUTLASS kernel template. +template +__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) +{ + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} +//////////////////////////////////////////////////////////////////////////////// +} /// namespace cutlass diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp new file mode 100644 index 0000000000..5b1b924549 --- /dev/null +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -0,0 +1,49 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +struct CollectiveEpilogue { + static_assert(std::is_void_v, "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "default_epilogue.hpp" +#include "epilogue.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp new file mode 100644 index 0000000000..71499b5d38 --- /dev/null +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/numeric/int.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_ +> +class DefaultEpilogue { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Params of epilogue::collective contain the epilogue::thread params + struct Params { + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + typename ThreadEpilogueOp::Params thread_params{}; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.epilogue_params}; + } + + CUTLASS_HOST_DEVICE + DefaultEpilogue(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp b/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp new file mode 100644 index 0000000000..7e38acd75b --- /dev/null +++ b/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/numeric/int.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_ +> +class DefaultTransposedEpilogue { + +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Params of epilogue::collective contain the epilogue::thread params + struct Params { + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + typename ThreadEpilogueOp::Params thread_params{}; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.epilogue_params}; + } + + CUTLASS_HOST_DEVICE + DefaultTransposedEpilogue(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Tranpose stride C/D. + auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC)); + auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD)); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/epilogue.hpp b/include/cutlass/epilogue/collective/epilogue.hpp new file mode 100644 index 0000000000..565e752ea0 --- /dev/null +++ b/include/cutlass/epilogue/collective/epilogue.hpp @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_ +> +class Epilogue { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + // Params of epilogue::collective contain the epilogue::thread params + struct Params { + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + typename ThreadEpilogueOp::Params thread_params{}; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.epilogue_params}; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) : params(params_) { }; + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + ThreadEpilogueOp epilogue_op{this->params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sC to match the accumulator partitioning + auto tC = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma).get_thread_slice(thread_idx); + Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sC), size<1>(sC)); + Tensor gCt = local_tile(gC, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = local_tile(gD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sC, gC, and gD for the output + auto tD = TiledCopyS2R{}.get_thread_slice(thread_idx); + Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = local_tile(cD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N + CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sC : "); print(sC.layout()); print("\n"); + print("\n"); + print("tCsC : "); print(tCsC.layout()); print("\n"); + print("tCaC : "); print(tCaC.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tDsC : "); print(tDsC.layout()); print("\n"); + print("tDrC : "); print(tDrC.layout()); print("\n"); + print("\n"); + print("tDrD : "); print(tDrD.layout()); print("\n"); + print("tDgC : "); print(tDgC.layout()); print("\n"); + print("tDgD : "); print(tDgD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) + { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) + { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { + int mma_m = step_m * size<1>(tCsC) + pipe_m; + int mma_n = step_n * size<2>(tCsC) + pipe_n; + + copy(tC, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tD, tDsC, tDrC); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); + Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tDgDmn); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tDgDmn); ++n) + { + // Predication + if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + { + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tDrC); ++i) { + tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); + } + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tDrC); ++i) { + tDrD(i) = epilogue_op(tDrC(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tDgDmn); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tDgDmn); ++n) + { + // Predication + if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + { + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp new file mode 100644 index 0000000000..de318d538f --- /dev/null +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -0,0 +1,39 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue { + +////////////////////////////////////////////////////////////////////////////// + +// +// Collective Epilogue Policies +// + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index b22c26d0ba..0c4b3849df 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -62,7 +62,8 @@ template < typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling - FloatRoundStyle Round = FloatRoundStyle::round_to_nearest + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, + typename ElementSource_ = ElementOutput_ > class LinearCombination { public: @@ -70,6 +71,8 @@ class LinearCombination { using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementC = ElementSource_; + using ElementD = ElementOutput_; static int const kCount = Count; static const ScaleType::Kind kScale = Scale; @@ -78,7 +81,6 @@ class LinearCombination { using ComputeFragment = Array; using ParamsBase = LinearCombinationParams; - static FloatRoundStyle const kRound = Round; /// Host-constructable parameters structure @@ -89,28 +91,28 @@ class LinearCombination { ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory CUTLASS_HOST_DEVICE - Params(): + Params(): ParamsBase( - ElementCompute(1), + ElementCompute(1), ElementCompute(0) ), - alpha(ElementCompute(1)), - beta(ElementCompute(0)), - alpha_ptr(nullptr), + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha, ElementCompute beta - ): + ): ParamsBase(alpha, beta), - alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } + alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha - ): + ): ParamsBase(alpha, ElementCompute(0)), alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } @@ -118,7 +120,7 @@ class LinearCombination { Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr - ): + ): ParamsBase(*alpha_ptr, *beta_ptr), alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } @@ -132,13 +134,13 @@ class LinearCombination { CUTLASS_HOST_DEVICE Params( ParamsBase const& base - ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { + ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { #if defined(__CUDA_ARCH__) alpha = reinterpret_cast(base.alpha_data); beta = reinterpret_cast(base.beta_data); #else - memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); - memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); + memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); + memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); #endif } }; @@ -184,7 +186,7 @@ class LinearCombination { /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( - FragmentAccumulator const &accumulator, + FragmentAccumulator const &accumulator, FragmentOutput const &source) const { // Convert source to interal compute numeric type @@ -236,8 +238,61 @@ class LinearCombination { ComputeFragment intermediate; multiplies mul_accumulator; - intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + return destination_converter(intermediate); + } + + // + // Specializations for scalar (for use with cute::collective::DefaultEpilogue) + // + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator, ElementC const source) const { + // Convert everything to Compute type, do compute, and then store to output type + NumericConverter accumulator_converter; + [[maybe_unused]] NumericConverter source_converter; + NumericConverter destination_converter; + + // Convert to destination numeric type + + ElementCompute converted_accumulator = accumulator_converter(accumulator); + if constexpr (Scale == ScaleType::Nothing) { + return destination_converter(converted_accumulator); + } + + // Perform binary operations + ElementCompute intermediate; + multiplies multiply; + multiply_add madd; + + if constexpr (Scale == ScaleType::NoBetaScaling) { + intermediate = source_converter(source); + } + else { + intermediate = multiply(beta_, source); // X = beta * C + uniform + } + + intermediate = madd(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + return destination_converter(intermediate); + } + + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator) const { + // Convert everything to Compute type, do compute, and then store to output type + NumericConverter accumulator_converter; + NumericConverter destination_converter; + ElementCompute converted_accumulator = accumulator_converter(accumulator); + + // Convert to destination numeric type + if constexpr (Scale == ScaleType::Nothing) { + return destination_converter(converted_accumulator); + } + + // Perform binary operations + ElementCompute intermediate; + multiplies multiply; + intermediate = multiply(alpha_, accumulator); // D = alpha * Accum return destination_converter(intermediate); } }; diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 972cf04bf1..277bad5c0f 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -56,6 +56,12 @@ struct absolute_value_op { } }; +template <> +struct absolute_value_op { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { return fabs(lhs); } +}; + template struct plus { CUTLASS_HOST_DEVICE @@ -83,6 +89,30 @@ struct multiplies { } }; +// Maximum with nan propogation +// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN +template +struct maximum_with_nan_propogation { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + return lhs > rhs or std::isnan(lhs) ? lhs : rhs; + } +}; + +template <> +struct maximum_with_nan_propogation { + CUTLASS_HOST_DEVICE + float operator()(float const lhs, float const rhs) const { + float res; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); +#else + res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; +#endif + return res; + } +}; + /// Squares with optional conversion template struct square { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl new file mode 100644 index 0000000000..c1444a9840 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// Some named constants +// +constexpr int tma_alignment_bytes = 16; +constexpr int cp_async_min_alignment_bytes = 4; +constexpr int sm90_smem_capacity_bytes = 232448; + +// Maps 2.x A matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +tag_to_gmma_major_A() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (std::is_same_v && + not std::is_same_v && + not std::is_same_v && + not std::is_same_v) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +// Maps 2.x B matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +tag_to_gmma_major_B() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (std::is_same_v && + not std::is_same_v && + not std::is_same_v && + not std::is_same_v) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the TMA atom that should be used with it +template +constexpr auto +cluster_shape_to_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { + static_assert(cute::rank(unimodal_cluster_shape) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(unimodal_cluster_shape) == 1) { + return cute::SM90_TMA_LOAD{}; + } + else { + return cute::SM90_TMA_LOAD_MULTICAST{}; + } +} + +// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. +template +constexpr auto +make_cp_async_gmem_tiled_copy() { + using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; + constexpr int TileSizeMN = cute::size(TileMN{}); + constexpr int TileSizeK = cute::size(TileK{}); + + // Maximize the number of threads along the gmem major mode to promote coalesced reads + // While making sure our thread layout tiles the threadblock tile evenly + if constexpr (cute::size<1>(StrideType{}) == 1) { + // K major thread layout for K major gmem + constexpr int threads_major = TileSizeK / Alignment; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cute::size<0>(StrideType{}) == 1) { + // MN major thread layout for MN major gmem + constexpr int threads_major = TileSizeMN / Alignment; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride< _1,Int>>{}, + Layout,_1>>{}); + } + else { + static_assert(std::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + } +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(int KernelSmemCarveout = 0) { + if constexpr (std::is_same_v) { + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int a_bytes = static_cast(sizeof(ElementA)); + constexpr int b_bytes = static_cast(sizeof(ElementB)); + constexpr int stage_bytes = + (a_bytes * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + (b_bytes * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + stage_barrier_bytes; + + return (CapacityBytes - KernelSmemCarveout) / stage_bytes; + } + else { + return StageCountType::value; + } +} + +// Kernel policy selection logic: auto dispatches to KernelTmaWarpSpecialized for now. Subject to change. +template < + class ElementA, + class ElementB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +constexpr auto +generate_gmma_dispatch_policy() { + if constexpr (std::is_base_of_v or + std::is_same_v) { + constexpr int PipelineStages = compute_stage_count_or_override< + sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(); + + if constexpr (std::is_same_v or + std::is_same_v) { + return MainloopSm90TmaGmmaWarpSpecialized{}; + } + else { + static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); + } + } + + else if constexpr (std::is_base_of_v) { + // For the persistent kernel, assume that the epilogue uses 1 MN tile worth of smem + constexpr int EpilogueTileCarveout = sizeof(ElementAccumulator) * + (size<0>(TileShape_MNK{}) * size<1>(TileShape_MNK{})); + constexpr int PipelineStages = compute_stage_count_or_override< + sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(EpilogueTileCarveout); + + if constexpr (std::is_same_v) { + return MainloopSm90TmaGmmaWarpSpecialized{}; + } + else { + static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); + } + } + + else if constexpr (std::is_base_of_v) { + constexpr int PipelineStages = compute_stage_count_or_override< + sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(); + + return MainloopSm90TmaGmma{}; + } + + else { + static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); + } +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_SS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + std::enable_if_t< + // TMA requires alignment be 16 bytes + ((sizeof(ElementA) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes == 0) && + not std::is_same_v && + // dispatch TN tf32 and int8 kernels only to TMA builder + ((sizeof(ElementA) == 2 && sizeof(ElementB) == 2) || + (std::is_same_v && std::is_same_v))> +> { + static_assert(is_static::value); + static_assert(is_static::value); + + #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(sizeof(ElementA) == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + #endif + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = std::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = std::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::tag_to_gmma_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::tag_to_gmma_major_B(); + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); + + using GmemTiledCopyA = decltype(detail::cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) + >()); + using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) + >()); + + using DispatchPolicy = decltype(detail::generate_gmma_dispatch_policy< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>()); + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, // GMMA_SS does not need an SmemCopyAtom + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, // GMMA_SS does not need an SmemCopyAtom + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync_SS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + std::enable_if_t< + // Even if we could build a TMA kernel, let the user override and use cp_async instead + std::is_same_v || + // But always guard against invalid TMA alignments and dispatch to cp_async + ((sizeof(ElementA) * AlignmentA) % detail::tma_alignment_bytes != 0) || + ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes != 0) || + // dispatch non-TN tf32 and int8 kernels only to cp_async builder + ((sizeof(ElementA) != 2 || sizeof(ElementB) != 2) && + (not std::is_same_v || not std::is_same_v))> +> { + static_assert(is_static::value); + static_assert(is_static::value); + + #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(sizeof(ElementA) == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + #endif + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = std::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = std::conditional_t, tfloat32_t, ElementB>; + + static_assert((sizeof(ElementA) * AlignmentA) % detail::cp_async_min_alignment_bytes == 0 and + (sizeof(ElementB) * AlignmentB) % detail::cp_async_min_alignment_bytes == 0, + "Minimum alignment required for cp.async is 4B."); + + static constexpr cute::GMMA::Major GmmaMajorA = detail::tag_to_gmma_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::tag_to_gmma_major_B(); + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); + + using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< + 128, ElementA, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< + 128, ElementB, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) + >()); + + using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) + >()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override< + detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, TileShape_MNK, StageCountType>(); + + using CollectiveOp = CollectiveMma< + MainloopSm90CpAsyncGmma, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, // GMMA_SS does not need an SmemCopyAtom + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, // GMMA_SS does not need an SmemCopyAtom + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp new file mode 100644 index 0000000000..3cd68a41de --- /dev/null +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "collective_mma.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { static constexpr int value = num_stages; }; +struct StageCountAuto {}; + +// Used to automatically let the builder pick the kernel schedule. +// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct KernelScheduleAuto {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_gmma_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp new file mode 100644 index 0000000000..a2a9067571 --- /dev/null +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomA, + class SmemCopyAtomA, + class TransformA, + class GmemTiledCopyB, + class SmemLayoutAtomB, + class SmemCopyAtomB, + class TransformB +> +struct CollectiveMma { + static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "sm70_mma_twostage.hpp" +#include "sm80_mma_multistage.hpp" +#include "sm90_mma_multistage_gmma_ss.hpp" +#include "sm90_mma_tma_gmma_ss.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp new file mode 100644 index 0000000000..11e5515aed --- /dev/null +++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/tensor_predicate.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm70TwoStageUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm70TwoStageUnpredicated; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + (void)residue_mnk; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + static_assert(rank(SmemLayoutB{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto copy_a_thr = gmem_tiled_copy_a.get_slice(thread_idx); + auto copy_b_thr = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = copy_a_thr.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = copy_a_thr.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBgB = copy_b_thr.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = copy_b_thr.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) + + // Allocate the register tiles for double buffering -- same shape as partitioned data + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_M,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + + auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + + // + // Prologue + // + + // Copy gmem to rmem for the first k_tile + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); + if (--k_tile_count > 0) ++k_tile_iter; + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + // Clear accumulators + __syncthreads(); + + // Load A, B smem->rmem for k=0 + copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); + // + // Mainloop + // + + // Size of the k-tiles's outer product mode (k) + auto K_BLOCK_MAX = size<2>(tCrA); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > -1) + { + // Pipeline the outer products with a static for loop + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + __syncthreads(); + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + } + + // Load A, B smem->rmem for k+1 + int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + if (k_block == 0) + { + // Copy gmem to rmem + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); + if (--k_tile_count > 0) ++k_tile_iter; + } + + // transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + + // Thread-level register gemm for k + // disambiguate gemm (shared with the namespace name) + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm70TwoStage, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm70TwoStage; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + static_assert(rank(SmemLayoutB{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA.data() = &gA(0, get<2>(residue_mnk), 0); + gB.data() = &gB(0, get<2>(residue_mnk), 0); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate the register tiles for double buffering -- same shape as partitioned data + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the rmem tiles to account for predicated off loads + clear(tArA); + clear(tBrB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tArA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tArA(_,_,k)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBrB(_,_,k)); + } + } + ++k_tile_iter; + --k_tile_count; + } + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)); // (MMA,MMA_M,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + + auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + + // + // Prologue + // + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + // Clear accumulators + __syncthreads(); + + // Load A, B smem->rmem for k=0 + copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); + // + // Mainloop + // + + // Size of the k-tiles's outer product mode (k) + auto K_BLOCK_MAX = size<2>(tCrA); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > -1) + { + // Pipeline the outer products with a static for loop + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + __syncthreads(); + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + } + + // Load A, B smem->rmem for k+1 + int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + if (k_block == 0) + { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tArA); + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBrB); + ++k_tile_iter; + --k_tile_count; + } + + // transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + + // Thread-level register gemm for k + // disambiguate gemm (shared with the namespace name) + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp new file mode 100644 index 0000000000..6ba6ccc008 --- /dev/null +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -0,0 +1,680 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm80CpAsyncUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80CpAsyncUnpredicated; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, + "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); + static_assert(rank(SmemLayoutB{}) == 3, + "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + (void) residue_mnk; + //assert(residue_mnk == make_tuple(0,0,0)); + + // + // PREFETCH + // + + // Start async loads for all pipes but the last + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); + copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); + cp_async_fence(); + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_iter; } + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_A) == size(tiled_mma)); + CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_B) == size(tiled_mma)); + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + if (k_tile_count > 0) { ++k_tile_iter; } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm80CpAsync, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80CpAsync; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, // (BLK_M, BLK_K, K_TILES) + TensorB gB, // (BLK_N, BLK_K, K_TILES) + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA.data() = &gA(0, get<2>(residue_mnk), 0); + gB.data() = &gB(0, get<2>(residue_mnk), 0); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Set all predicates to false if we are going to overshoot bounds + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + ++k_tile_iter; + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp new file mode 100644 index 0000000000..3b1921b9cc --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp @@ -0,0 +1,596 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" + +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90CpAsyncGmmaUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmmaUnpredicated; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(std::is_base_of::value && + std::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + TensorA gA, + TensorB gB, + FrgTensorC& accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf, + Params const& mainloop_params) + { + using namespace cute; + + (void) residue_mnk; + + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(std::is_same::value, + "SM90 warpgroup MMA must specify transforms through MMA_Atom."); + static_assert(std::is_same::value, + "SM90 warpgroup MMA must specify transforms through MMA_Atom."); + static_assert(std::is_same::value, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(std::is_same::value, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Tile MMA atom and compute thread partitions across A, B and C + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate registers for pipelining + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // Prologue + // + + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + // + // Pipelined Main Loop + // + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Copy gmem to smem before computing gemm on each k-pipe + // pipe index in smem where the next gmem tile will be read into + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + if (k_tile_count > 0) { ++k_tile_iter; } + + // + // Compute on k_tile + // + warpgroup_fence_operand(accum); + warpgroup_arrive(); + + cp_async_wait(); + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), accum); + warpgroup_commit_batch(); + + // + // Advance the pipe + // + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; + + ++smem_pipe_write; + smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90CpAsyncGmma, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmma; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(std::is_base_of::value && + std::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + struct Params { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(std::is_same::value, + "SM90 warpgroup MMA must specify transforms through MMA_Atom."); + static_assert(std::is_same::value, + "SM90 warpgroup MMA must specify transforms through MMA_Atom."); + static_assert(std::is_same::value, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(std::is_same::value, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA.data() = &gA(0, get<2>(residue_mnk), 0); + gB.data() = &gB(0, get<2>(residue_mnk), 0); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // Prologue/PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA atom and compute thread partitions across A, B and C + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate registers for pipelining + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(src_accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(src_accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + // + // Pipelined Main Loop + // + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // + // Copy gmem to smem for *k_tile_iter + // + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); // CpAsync + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + + // + // Compute on k_tile + // + warpgroup_fence_operand(accum); + warpgroup_arrive(); + + cp_async_wait(); + cute::gemm(tiled_mma, accum, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), src_accum); + warpgroup_commit_batch(); + + // + // Advance the pipe + // + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; + + ++smem_pipe_write; + smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp new file mode 100644 index 0000000000..25eaffb74b --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -0,0 +1,480 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmma, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmma; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages, + typename DispatchPolicy::ClusterShape>; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(std::is_base_of::value && + std::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(std::is_same_v || std::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(std::is_same_v || std::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = std::is_same_v; + static constexpr bool ConvertF32toTF32B = std::is_same_v; + using InternalElementA = std::conditional_t>>; + using InternalElementB = std::conditional_t>>; + + struct SharedStorage + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + struct Params { + InternalElementA const* ptr_A; + StrideA dA; + InternalElementB const* ptr_B; + StrideB dB; + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); + auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + reinterpreted_ptr_A, + args.dA, + reinterpreted_ptr_B, + args.dB, + tma_load_a, + tma_load_b + }; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class TMA_LOAD_B, + class FrgTensorC, + class KTileIterator + > + CUTLASS_DEVICE void + operator() ( + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + FrgTensorC& accum, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + char* shared_memory, + Params const& mainloop_params) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(std::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(std::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + SharedStorage& storage = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + dim3 cluster_local_block_id = cute::block_id_in_cluster(); + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // + // Prepare TMA membars and PREFETCH + // + + // Number of pipelined k-tiles in smem + constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + + // NOTE: Another parameter: Partition the pipeline between active MMAs and active TMAs + // Tunable via the dispatch policy to tollerate latencies evenly across the math and compute stages + // K_PIPE_MMAS: The max number of active MMA pipes at beginning of every loop + // K_PIPE_TMAS: The max number of active TMA pipes at beginning of every loop (geq 1) + constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + constexpr int K_PIPE_TMAS = K_PIPE_MAX - K_PIPE_MMAS; + static_assert(0 <= K_PIPE_MMAS && K_PIPE_MMAS < K_PIPE_MAX); + static_assert(0 < K_PIPE_TMAS && K_PIPE_TMAS <= K_PIPE_MAX); + + static_assert(K_PIPE_MMAS < K_PIPE_MAX - 1); + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr uint32_t TmaTransactionBytes = static_cast( + (size<0>(sA) * size<1>(sA) * sizeof(InternalElementA)) + + (size<0>(sB) * size<1>(sB) * sizeof(InternalElementB))); + + + // Obtain warp index + int warp_idx = canonical_warp_idx(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + + PipelineParams params; + params.transaction_bytes = TmaTransactionBytes; + params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + params.is_leader = warp_group_thread_idx == 0; + params.num_consumers = NumThreadsPerWarpGroup; + + MainloopPipeline pipeline( + storage.pipeline_storage, + params); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read; + PipelineState smem_pipe_release; + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + // Set predicate for the lowest lane_id in the warp + int lane_predicate = cute::elect_one_sync(); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + // Keep a copy to know when to stop issuing loads + int k_tile_count_tma = k_tile_count; + + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0 && lane_predicate == 1) { + // Maps the tile -> block, value + if constexpr (std::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (std::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Issue the prologue loads + int prologue_tma_count = min(K_PIPE_MAX, k_tile_count); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < prologue_tma_count; ++stage) { + pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename MainloopPipeline::ValueType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(stage); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,stage)); + ++k_tile_iter; + ++smem_pipe_write; + } + k_tile_count_tma -= prologue_tma_count; + } + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + __syncthreads(); + + warpgroup_fence_operand(accum); + // Prologue MMAs + CUTLASS_PRAGMA_UNROLL + for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + prologue_mma_count > 0; --prologue_mma_count) + { + // WAIT on smem_pipe_read until it's data is available + pipeline.consumer_wait(smem_pipe_read); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + ++smem_pipe_read; + --k_tile_count; + } + warpgroup_fence_operand(accum); + + // + // PIPELINED MAIN LOOP + // + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until data is available + pipeline.consumer_wait(smem_pipe_read); + + // + // Compute on k_tile + // + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK wr stage, done _computing_ on it + + // + // Copy gmem to smem for *k_tile_iter + // + + // Do Acquire & Load only if needed - helps with both performance and also corner case illegal barrier-ops + if (warp_idx == 0 && lane_predicate == 1 && (k_tile_count_tma > 0) ) { + pipeline.producer_acquire(smem_pipe_write); // LOCK wr stage, for _writing_ + + using BarrierType = typename MainloopPipeline::ValueType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write.index()); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write.index())); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write.index())); + ++smem_pipe_write; + ++k_tile_iter; + --k_tile_count_tma; + } + + // Advance consumer pipeline + ++smem_pipe_read; + ++smem_pipe_release; + } + + // Wait on all GMMAs + warpgroup_wait<0>(); + warpgroup_fence_operand(accum); + + // Workaround for ensuring Smem destruction doesn't happen accidentally + if constexpr (size(typename DispatchPolicy::ClusterShape{}) > 1) { + cute::cluster_arrive(); + cute::cluster_wait(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..41b0f13b65 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages, + typename DispatchPolicy::ClusterShape>; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(std::is_base_of::value && + std::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(std::is_same_v || std::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(std::is_same_v || std::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = std::is_same_v; + static constexpr bool ConvertF32toTF32B = std::is_same_v; + using InternalElementA = std::conditional_t>>; + using InternalElementB = std::conditional_t>>; + + struct SharedStorage + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + struct Params { + InternalElementA const* ptr_A; + StrideA dA; + InternalElementB const* ptr_B; + StrideB dB; + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(Args const& args, void* workspace) { + (void) workspace; + // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); + auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + reinterpreted_ptr_A, + args.dA, + reinterpreted_ptr_B, + args.dB, + tma_load_a, + tma_load_b + }; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + + CUTLASS_DEVICE + static MainloopPipeline make_pipeline(char* shared_memory, PipelineParams params){ + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + return {shared_storage.pipeline_storage, params}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class TMA_LOAD_B, + class KTileIterator + > + CUTLASS_DEVICE void + dma(MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + char* shared_memory) + { + + using namespace cute; + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + dim3 cluster_local_block_id = cute::block_id_in_cluster(); + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (std::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (std::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Issue the prologue loads + int k_tile_prologue = min(k_tile_count, K_PIPE_MAX); + CUTLASS_PRAGMA_UNROLL + for (int count = 0; count < k_tile_prologue; ++count) { + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ValueType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + ++smem_pipe_write; + } + k_tile_count -= k_tile_prologue; + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ValueType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + dma_epilogue(MainloopPipeline pipeline, + PipelineState smem_pipe_write) + { + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + for (int count = 0; count < K_PIPE_MAX; ++count) { + pipeline.producer_acquire(smem_pipe_write); + ++smem_pipe_write; + } + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + char* shared_memory, + Params const& mainloop_params + ) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(std::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(std::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + warpgroup_fence_operand(accum); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 657488c564..66884fb26b 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -36,20 +36,369 @@ #pragma once +// common #include "cutlass/cutlass.h" +#include "cutlass/trace.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" + +// 2.x #include "cutlass/gemm/device/gemm_universal_base.h" #include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { //////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace gemm { -namespace device { +/*! + GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, new static methods + are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of GemmUniversalAdapter + on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might + differ between the two specializations. +*/ +template +class GemmUniversalAdapter; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, + std::enable_if_t::value>> +{ +public: + using GemmKernel = GemmKernel_; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + // NOTE: 3.0 kernels do not support complex transforms for now ... + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! + using OperatorClass = std::conditional_t< + (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + static_assert(std::is_void_v, + "CUTLASS 3.x kernel types do not support grid swizzle functors yet."); + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, + cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, + cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); + static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); + + // NOTE: 3.0 DefaultEpilogues don't support vectorized stores (yet) + static int constexpr kAlignmentC = 1; + static int constexpr kAlignmentD = 1; + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = std::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } -///////////////////////////////////////////////////////////////////////////////////////////////// + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args) { + auto tmp_params = GemmKernel::to_underlying_arguments(args); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = GemmKernel::get_workspace_size(args); + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + + // account for dynamic smem capacity if needed + int smem_size = GemmKernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 constexpr block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(GemmKernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) const { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// template -class GemmUniversalAdapter { +class GemmUniversalAdapter< + GemmKernel_, + std::enable_if_t::value>> +{ public: using GemmKernel = GemmKernel_; @@ -193,10 +542,8 @@ class GemmUniversalAdapter { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace cutlass::gemm::device -///////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp new file mode 100644 index 0000000000..a2cd9a1117 --- /dev/null +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/arch.h" + +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm { +using namespace cute; + +////////////////////////////////////////////////////////////////////////////// + +// +// Policies for categorical dispatch of mainloop against kernel grid schedules +// +struct KernelMultistage { }; +struct KernelTma { }; +struct KernelTmaWarpSpecialized { }; +struct KernelTmaWarpSpecializedPersistent { }; + +// +// Collective Mainloop Policies +// + +// 2 stage pipeline through 1 stage in smem, 1 in rmem, WITHOUT predicated gmem loads +struct MainloopSm70TwoStageUnpredicated { + constexpr static int Stages = 2; + using ArchTag = arch::Sm70; + using Schedule = KernelMultistage; + using ClusterShape = Shape<_1,_1,_1>; +}; + +// 2 stage pipeline through 1 stage in smem, 1 in rmem, with predicated gmem loads +struct MainloopSm70TwoStage { + constexpr static int Stages = 2; + using ArchTag = arch::Sm70; + using Schedule = KernelMultistage; + using ClusterShape = Shape<_1,_1,_1>; +}; + +// n-buffer in smem (cp.async), pipelined with registers, WITHOUT predicated gmem loads +template +struct MainloopSm80CpAsyncUnpredicated { + constexpr static int Stages = Stages_; + using ArchTag = arch::Sm80; + using Schedule = KernelMultistage; + using ClusterShape = Shape<_1,_1,_1>; +}; + +// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads +template +struct MainloopSm80CpAsync { + constexpr static int Stages = Stages_; + using ArchTag = arch::Sm80; + using Schedule = KernelMultistage; + using ClusterShape = Shape<_1,_1,_1>; +}; + +// n-buffer in smem (cp.async), pipelined with Hopper GMMA, WITHOUT predicated gmem loads +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm90CpAsyncGmmaUnpredicated { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelMultistage; +}; + +// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm90CpAsyncGmma { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelMultistage; +}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + int PipelineAsyncMmaStages_ = 1 +> +struct MainloopSm90TmaGmma { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; + using ArchTag = arch::Sm90; + using Schedule = KernelTma; +}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecialized +> +struct MainloopSm90TmaGmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 96a08de3f4..4b76101b28 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -35,6 +35,9 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" +#include "cutlass/layout/matrix.h" +#include "cute/layout.hpp" +#include "cute/arch/copy_sm90.hpp" namespace cutlass { namespace gemm { @@ -420,6 +423,151 @@ enum class SharedMemoryClearOption { //////////////////////////////////////////////////////////////////////////////////////////////////// +// For each cutlass::layout, provides its corresponding cute stride types, 64b by default + +template +struct TagToStrideA {}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::ColumnMajor; +}; + +template +struct TagToStrideB {}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t>; + using tag = layout::ColumnMajor; +}; + + +// Maps to modes [N, N, L] +template +struct TagToStrideC : TagToStrideA { }; + +// Convenience aliases +template +using TagToStrideA_t = typename TagToStrideA::type; + +template +using TagToStrideB_t = typename TagToStrideB::type; + +template +using TagToStrideC_t = typename TagToStrideC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For 2.x compatibility APIs, provide stride->layout tag mappers + +namespace detail { + +// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices +template +constexpr +auto +stride_to_layout_tag_A() { + // Account for stride types with and without batch mode and batch modes with static zero stride + if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major + return layout::ColumnMajor{}; + } + else { // K major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_B() { + // Account for stride types with and without batch mode and batch modes with static zero stride + if constexpr (cute::size<0>(StrideB{}) == 1) { // N major + return layout::RowMajor{}; + } + else { // K major + return layout::ColumnMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// Inspects a TiledCopy and returns its alignment in terms of element count +template +constexpr int +get_alignment_count_from_gmem_tiled_copy() { + // For TMA tiled copies, we know the alignment has to be 128 bits + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + return 128 / sizeof_bits::value; + } + else + { + // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN + return GmemTiledCopy::NumValSrc; + } +} + +// Utilities to map Stride back on to their corresponding layout tags +template +struct StrideToLayoutTagA { + using type = decltype(detail::stride_to_layout_tag_A()); +}; + +template +struct StrideToLayoutTagB { + using type = decltype(detail::stride_to_layout_tag_B()); +}; + +// Maps to modes [N, N, L] +template +struct StrideToLayoutTagC : StrideToLayoutTagA { }; + +// Convenience aliases +template +using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; + +template +using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; + +template +using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; + +/////////////////////////////////////////////////////////////////////////////// + +// The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` +// is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. +template +struct IsCutlass3GemmKernel : std::false_type { }; + +template +struct IsCutlass3GemmKernel> + : std::true_type { }; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index f6a312367c..4432008e65 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -262,8 +262,8 @@ struct DefaultGemm { - static_assert(platform::is_same::value - || platform::is_same>::value, + static_assert((platform::is_same::value + || platform::is_same>::value), "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate @@ -714,8 +714,8 @@ struct DefaultGemm< PermuteDLayout, typename platform::enable_if< ! platform::is_same::value >::type > { - static_assert(platform::is_same::value - || platform::is_same>::value, + static_assert((platform::is_same::value + || platform::is_same>::value), "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate @@ -841,8 +841,8 @@ struct DefaultGemm { - static_assert(platform::is_same::value - || platform::is_same>::value, + static_assert((platform::is_same::value + || platform::is_same>::value), "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index 1427acbb2b..b5064ec7cf 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -256,7 +256,7 @@ struct Gemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h index 2e226a9748..1862e206fd 100644 --- a/include/cutlass/gemm/kernel/gemm_array.h +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -193,7 +193,7 @@ struct GemmArray { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index 489a899937..464aeef51d 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -204,7 +204,7 @@ struct GemmBatched { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_grouped.h b/include/cutlass/gemm/kernel/gemm_grouped.h index fd3d7f2f7f..84dc4aeec9 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/include/cutlass/gemm/kernel/gemm_grouped.h @@ -395,7 +395,7 @@ struct GemmGrouped { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_pipelined.h b/include/cutlass/gemm/kernel/gemm_pipelined.h index 93faa2cc15..df450d08c7 100644 --- a/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -111,7 +111,7 @@ __global__ void GemmPipelined( tb_thread_id, tb_offset_B); - int warp_id = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_id = canonical_warp_idx(); int lane_id = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index a2c24b258d..7dbc5923f9 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -525,7 +525,7 @@ struct GemmPlanarComplex { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index b990d6c298..21b801149a 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -467,7 +467,7 @@ struct GemmPlanarComplexArray { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 7ddb76d57c..fc62c01bf3 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -42,6 +42,8 @@ #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/layout/matrix.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/kernel/params_universal_base.h" @@ -61,7 +63,15 @@ template < typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_ ///! Threadblock swizzling function > -struct GemmUniversal { +class GemmUniversal< + Mma_, + Epilogue_, + ThreadblockSwizzle_, + void, + // 3.x kernels use the first template argument to define the ProblemShape tuple + // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API + std::enable_if_t::value> +> { public: using Mma = Mma_; @@ -528,7 +538,7 @@ struct GemmUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp new file mode 100644 index 0000000000..cdac6ca488 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. +**/ +template < + class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) + class CollectiveMainloopOrEpilogue_, + class CollectiveEpilogueOrThreadblockSwizzle_, + class GridSwizzle_ = void, + class Enable = void +> +class GemmUniversal; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/kernel/sm70_gemm.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_tma.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index 7ab9d13968..8f67bd4577 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -918,7 +918,7 @@ struct GemmWithFusedEpilogue { lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << this->ldt); @@ -1019,7 +1019,7 @@ struct GemmWithFusedEpilogue { batch_stride_Tensor(args.batch_stride_Tensor) { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << args.ldt); @@ -1222,7 +1222,7 @@ struct GemmWithFusedEpilogue { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index 5145fb5db9..8e00e184d5 100644 --- a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -505,7 +505,7 @@ struct GemmWithKReduction { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/params_universal_base.h b/include/cutlass/gemm/kernel/params_universal_base.h index 1e77ea9c99..453379d448 100644 --- a/include/cutlass/gemm/kernel/params_universal_base.h +++ b/include/cutlass/gemm/kernel/params_universal_base.h @@ -189,15 +189,16 @@ struct UniversalParamsBase void *workspace, cudaStream_t stream = nullptr) { + semaphore = static_cast(workspace); // Zero-initialize entire workspace - if (workspace) + if (semaphore) { size_t workspace_bytes = get_workspace_size(); CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); cudaError_t result = cudaMemsetAsync( - workspace, + semaphore, 0, workspace_bytes, stream); @@ -208,7 +209,6 @@ struct UniversalParamsBase } } - semaphore = static_cast(workspace); return Status::kSuccess; } diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index b93ecb2df2..1c840e7aff 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -525,7 +525,7 @@ struct Rank2KGrouped { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/rank_2k_universal.h b/include/cutlass/gemm/kernel/rank_2k_universal.h index c1ae5d33bb..6d1f4ac2ff 100644 --- a/include/cutlass/gemm/kernel/rank_2k_universal.h +++ b/include/cutlass/gemm/kernel/rank_2k_universal.h @@ -450,7 +450,7 @@ struct Rank2KUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/rank_k_universal.h b/include/cutlass/gemm/kernel/rank_k_universal.h index 3eaf595bf4..b7d1ad1958 100644 --- a/include/cutlass/gemm/kernel/rank_k_universal.h +++ b/include/cutlass/gemm/kernel/rank_k_universal.h @@ -403,7 +403,7 @@ struct RankKUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp new file mode 100644 index 0000000000..efe51e23c9 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + std::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(std::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + static constexpr int SharedStorageSize = cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage)); + + static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + EpilogueParams epilogue_params{}; + KernelHardwareInfo hw_info; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args, workspace), + CollectiveEpilogue::to_underlying_arguments(args, workspace) + }; + } + + static + bool + can_implement(Arguments const& args) { + return args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + static constexpr + dim3 + get_grid_shape(Params const& params) { + int batch_count = 1; + if constexpr (rank(ProblemShape{}) == 4) { + batch_count = cute::size<3>(params.problem_shape); + } + + return dim3( + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + batch_count + ); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + // Separate out problem shape for convenience + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + int thread_idx = int(threadIdx.x); + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto [m_coord, n_coord, l_coord] = blockIdx; + auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l) + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) + + // Get batch slice + Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) + Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) + + // Slice to get the tiles this thread block is responsible for + Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + int k_tile_count = size<2>(gA); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + collective_mma( + accumulators, + gA, + gB, + accumulators, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + smem_buf + ); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + blk_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp new file mode 100644 index 0000000000..bd82ed111e --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -0,0 +1,301 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" + +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +namespace detail { + +// IF_SWAP_AB::value will be true only if: +// class T has member SwapAB and T::SwapAB is true +template +struct IF_SWAP_AB { static constexpr bool value = false; }; + +template +struct IF_SWAP_AB > +{ static constexpr bool value = T::SwapAB; }; + +} // namespace + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + std::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(std::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + static constexpr int SharedStorageSize = cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage)); + + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + EpilogueParams epilogue_params{}; + KernelHardwareInfo hw_info; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args, workspace), + CollectiveEpilogue::to_underlying_arguments(args, workspace) + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + return args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + // Computes the kernel launch grid shape based on runtime parameters + static constexpr + dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + int thread_idx = int(threadIdx.x); + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + // Separate out problem shape for convenience + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice + + // Make tiled views + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + clear(accumulators); + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + collective_mma( + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + accumulators, + k_tile_iter, k_tile_count, + thread_idx, + smem_buf, + params.mainloop + ); + + constexpr int BLK_M_RANK = rank<0>(blk_shape); + bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); + })); + + constexpr int BLK_N_RANK = rank<1>(blk_shape); + bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); + })); + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + output_tile_coord, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000..9fc719e2dc --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -0,0 +1,351 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + std::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(std::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + static constexpr int SharedStorageSize = cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage)); + + static constexpr uint32_t NumDmaWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumDmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + EpilogueParams epilogue_params{}; + KernelHardwareInfo hw_info; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args, workspace), + CollectiveEpilogue::to_underlying_arguments(args, workspace) + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + return args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + // Computes the kernel launch grid shape based on runtime parameters + static constexpr + dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + int thread_idx = int(threadIdx.x); + int warp_idx = canonical_warp_idx(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + using Pipeline = typename CollectiveMainloop::MainloopPipeline; + + using PipelineParams = typename CollectiveMainloop::PipelineParams; + PipelineParams params_pipeline; + params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + if (warp_group_role == WarpGroupRole::Producer) { + params_pipeline.role = Pipeline::ThreadCategory::Producer; + } + else { + params_pipeline.role = Pipeline::ThreadCategory::Consumer; + } + params_pipeline.is_leader = warp_group_thread_idx == 0; + params_pipeline.num_consumers = NumThreadsPerWarpGroup; + + // Initialize pipeline and setup starting pipeline state for the collectives + Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Separate out problem shape for convenience + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice + + // Make tiled views + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + // In a warp specialized kernel, CollectiveMainloop exposes data movement and compute operations separately + CollectiveMainloop collective_mainloop; + + if (warp_group_role == WarpGroupRole::Producer) { + // For the DMA (prologue) - we start with an opposite phase - since we skip all waits + // i.e., we know that the buffer is indeed empty + typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + collective_mainloop.dma( + pipeline, + smem_pipe_write, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + thread_idx, + smem_buf + ); + // Update starting pipeline state for the next tile + smem_pipe_write.advance(k_tile_count); + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); + } + else if (warp_group_role == WarpGroupRole::Consumer) { + typename CollectiveMainloop::PipelineState smem_pipe_read; + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + collective_mainloop.mma( + pipeline, + smem_pipe_read, + accumulators, + k_tile_count, + thread_idx, + smem_buf, + params.mainloop + ); + + constexpr int BLK_M_RANK = rank<0>(blk_shape); + bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); + })); + + constexpr int BLK_N_RANK = rank<1>(blk_shape); + bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); + })); + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + output_tile_coord, + accumulators, + tiled_mma, + residue_mnk, + warp_group_thread_idx, + smem_buf + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp new file mode 100644 index 0000000000..498bfad436 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" + +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + std::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(std::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + static constexpr uint32_t NumDmaWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for DMA and MATH WGs + static constexpr uint32_t DmaRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + /* Order Sequence barrier with two stages: one for Mainloop and one for Epilogue */ + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< + StagesPerMathWarpGroup, NumMmaWarpGroups>; + + // Kernel level shared memory storage + struct SharedStorage { + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order_barrier_storage; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + EpilogueParams epilogue_params{}; + KernelHardwareInfo hw_info; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + KernelHardwareInfo hw_info; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args, workspace), + CollectiveEpilogue::to_underlying_arguments(args, workspace), + {args.hw_info.device_id, sm_count} + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + + // Number of blocks per problem (without batch) must not exceed 2^31 for the persistent scheduler to calculate using FastDivmod + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); + uint64_t problem_blocks = problem_blocks_m * problem_blocks_n * problem_blocks_l; + implementable = implementable && (problem_blocks < (uint64_t(1) << 31)); + + return implementable; + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + // Computes the kernel launch grid shape based on runtime parameters + static constexpr + dim3 + get_grid_shape(Params const& params) { + int sm_count = params.hw_info.sm_count; + CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); + + // Compute the total number of output tiles our problem has + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); + int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; + + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + dim3 launch_grid(1, cute::size<1>(ClusterShape{}), 1); + + // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 + if constexpr (size(ClusterShape{}) == 1) { + launch_grid.x = std::min(sm_count, problem_blocks_total); + } + else { + /* + * Optimal grid size calculation is based on + * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU + * Hence, maximum SMs per GPC = 18 + */ + constexpr int max_sm_per_gpc = 18; + // Provided SM count could possibly be less than the assumed maximum SMs per GPC + int min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; + int max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(ClusterShape{})); + int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; + + launch_grid.x = std::min( + blk_per_device / size<1>(ClusterShape{}), + problem_blocks_total / size<1>(ClusterShape{})); + } + + return launch_grid; + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_idx = canonical_warp_idx(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + using Pipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename CollectiveMainloop::PipelineParams; + PipelineParams params_pipeline; + params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + if (warp_group_role == WarpGroupRole::Producer) { + params_pipeline.role = Pipeline::ThreadCategory::Producer; + } + else { + params_pipeline.role = Pipeline::ThreadCategory::Consumer; + } + params_pipeline.is_leader = warp_group_thread_idx == 0; + params_pipeline.num_consumers = NumThreadsPerWarpGroup; + + // Initialize pipeline and setup starting pipeline state for the collectives + Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); + typename CollectiveMainloop::PipelineState collective_start_state_pipe; + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.math_wg_order_barrier_storage, params_math_wg_order_barrier); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Separate out problem shape for convenience + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice + + // Slice to get the tiles this thread block is responsible for + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Get iterations along k-dimension + auto k_tile_count = size<3>(gA_mkl); + + detail::PersistentTileSchedulerSm90 scheduler(problem_shape_MNKL, blk_shape, ClusterShape{}); + + if (warp_group_role == WarpGroupRole::Consumer1) { + /* Advance 2nd Math WG to the next work tile for the startup */ + scheduler.advance_to_next_work(); + /* Advance 2nd Math WG pipeline state to the end of 1st Math WG */ + collective_start_state_pipe.advance(k_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mainloop; + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // For the DMA (prologue) - we start with an opposite phase - since we skip all waits + // i.e., we know that the buffer is indeed empty + typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + + collective_mainloop.dma( + pipeline, + smem_pipe_write, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + thread_idx, + reinterpret_cast(&shared_storage.mainloop) + ); + // Update starting pipeline state for the next tile + smem_pipe_write.advance(k_tile_count); + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + cutlass::arch::warpgroup_reg_alloc(); + + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + /* Order two Math WG's MMA one after the other, helps hide Epilogue */ + math_wg_order_barrier.wait(); + + collective_mainloop.mma( + pipeline, + collective_start_state_pipe, + accumulators, + k_tile_count, + thread_idx, + reinterpret_cast(&shared_storage.mainloop), + params.mainloop + ); + + /* Cue for next Math WG's MMA to start */ + math_wg_order_barrier.arrive(); + + /* Order two Math WG's Epilogue one after the other */ + math_wg_order_barrier.wait(); + + constexpr int BLK_M_RANK = rank<0>(blk_shape); + bool m_oob = int(work_tile_info.M_idx) >= size<2>(gA_mkl); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); + })); + + constexpr int BLK_N_RANK = rank<1>(blk_shape); + bool n_oob = int(work_tile_info.N_idx) >= size<2>(gB_nkl); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); + })); + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + residue_mnk, + warp_group_thread_idx, + reinterpret_cast(&shared_storage.epilogue) + ); + + /* Cue for next Math WG's Epilogue to start */ + math_wg_order_barrier.arrive(); + + // Update starting pipeline state for the next tile + collective_start_state_pipe.advance(k_tile_count * NumMmaWarpGroups); + + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp new file mode 100644 index 0000000000..496d5e0703 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/fast_math.h" +#include "cute/layout.hpp" + +namespace cutlass::gemm::kernel::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Persistent Thread Block (TB) scheduler +class PersistentTileSchedulerSm90 { + // + // Data members + // + +private: + uint32_t blocks_per_problem_; + uint32_t current_work_linear_idx_; + uint32_t grid_blocks_total_; + + FastDivmod divmod_batch_; + FastDivmod divmod_grid_y_; + FastDivmod divmod_blk_m_; + + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + uint32_t is_valid_tile = false; + }; + + // + // Methods + // + +public: + + template + CUTLASS_DEVICE + PersistentTileSchedulerSm90(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(is_static::value); + static_assert(is_static::value); + + // Round up to nearest multiple of cluster dim along each mode + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl( + problem_shape_mnkl, tile_shape, cluster_shape); + + blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks_l; + current_work_linear_idx_ = (int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y); + grid_blocks_total_ = int(gridDim.x) * int(gridDim.y); + + // Pre-compute our fast div/mods for rasterization so we don't have to pay for DIVs + divmod_batch_ = FastDivmod(problem_blocks_m * problem_blocks_n); + divmod_grid_y_ = FastDivmod(size<1>(cluster_shape)); + divmod_blk_m_ = FastDivmod(problem_blocks_m); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices + int work_idx_l, remainder; + divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); + + int blk_per_grid_dim, dontcare; + divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); + + int block_idx_m, block_idx_n; + divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); + int work_idx_m = block_idx_m; + int work_idx_n = (block_idx_n * gridDim.y) + blockIdx.y; + + return {work_idx_m, work_idx_n, work_idx_l, current_work_linear_idx_ < blocks_per_problem_}; + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += grid_blocks_total_ * advance_count; + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. + template + CUTLASS_HOST_DEVICE constexpr static + dim3 + get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) { + // Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles + auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape))); + auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape))); + + // Round up to nearest multiple of cluster dim along each mode + int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape)); + int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape)); + + // Cluster tile does not span the batch mode, so no extra rounding up required for it + int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl)); + return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)}; + } +}; + +} // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index f7b2678111..eba95aad4c 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -277,7 +277,7 @@ struct SparseGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/symm_universal.h b/include/cutlass/gemm/kernel/symm_universal.h index 4bab2cf939..47e7035abe 100755 --- a/include/cutlass/gemm/kernel/symm_universal.h +++ b/include/cutlass/gemm/kernel/symm_universal.h @@ -415,7 +415,7 @@ struct SymmUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/trmm_universal.h b/include/cutlass/gemm/kernel/trmm_universal.h index 69e5563de1..7ba223bbb4 100644 --- a/include/cutlass/gemm/kernel/trmm_universal.h +++ b/include/cutlass/gemm/kernel/trmm_universal.h @@ -380,7 +380,7 @@ struct TrmmUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h index 9d46a14153..995796796d 100644 --- a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h @@ -44,7 +44,7 @@ #include "cutlass/matrix_shape.h" #include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/gemm/gemm.h" @@ -120,9 +120,9 @@ class MmaWithReductionTensorOp { /// Underlying matrix multiply operator (concept: arch::Mma) using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator + /// Indicates math operator using MathOperator = typename ArchMmaOperator::Operator; - + /// Architecture tag from underlying instruction using ArchTag = typename ArchMmaOperator::ArchTag; @@ -223,9 +223,9 @@ class MmaWithReductionTensorOp { /// Performs a warp-level matrix multiply-accumulate operation CUTLASS_DEVICE void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, FragmentC const &C, FragmentReduction &gemm_k_reduction ) const { @@ -236,9 +236,9 @@ class MmaWithReductionTensorOp { D = C; - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); + [[maybe_unused]] MmaOperandA const *ptr_A = reinterpret_cast(&A); + [[maybe_unused]] MmaOperandB const *ptr_B = reinterpret_cast(&B); + [[maybe_unused]] MmaOperandC *ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) assert(0); @@ -258,7 +258,7 @@ class MmaWithReductionTensorOp { ptr_D[m + n_serpentine * MmaIterations::kRow]); if (!kReduceKForA && m == 0) { - #if 0 + #if 0 gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); @@ -306,12 +306,12 @@ class MmaWithReductionTensorOp { } if (kReduceKForA && (n == 0)) { - #if 0 + #if 0 gemm_k_reduction[m * 2] += float(A[m * 8]); gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); - + gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); @@ -411,9 +411,9 @@ class MmaWithReductionTensorOp { Array * ptr_dst_B = reinterpret_cast *>(&dst_B); - + dst_A = convert_A(A); - + ptr_dst_B[0] = convert_B(ptr_B[0]); ptr_dst_B[1] = convert_B(ptr_B[1]); @@ -429,9 +429,9 @@ class MmaWithReductionTensorOp { Array * ptr_dst_A = reinterpret_cast *>(&dst_A); - + dst_B = convert_B(B); - + ptr_dst_A[0] = convert_A(ptr_A[0]); ptr_dst_A[1] = convert_A(ptr_A[1]); #else diff --git a/include/cutlass/kernel_hardware_info.hpp b/include/cutlass/kernel_hardware_info.hpp new file mode 100644 index 0000000000..3ae09324c5 --- /dev/null +++ b/include/cutlass/kernel_hardware_info.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cuda_runtime.h" + +#include "cutlass/trace.h" + +namespace cutlass { + +struct KernelHardwareInfo { + // + // Data members + // + int device_id = 0; + int sm_count = 0; + + // + // Methods + // + + static int + query_device_multiprocessor_count(int device_id = 0) { + cudaError_t result = cudaGetDevice(&device_id); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; + } + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_id); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaGetDeviceProperties() returned error " + << cudaGetErrorString(result)); + return 0; + } + return properties.multiProcessorCount; + } +}; + +} // namespace cutlass diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index 51100f40c0..fe7a848934 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -39,6 +39,8 @@ */ #pragma once +#include "cute/layout.hpp" + #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/matrix_coord.h" @@ -143,6 +145,15 @@ class RowMajor { LongIndex capacity(MatrixCoord const &extent) const { return LongIndex(extent.row()) * LongIndex(stride_[0]); } + + CUTLASS_HOST_DEVICE + cute::Layout, cute::Stride > > + to_cute_layout(MatrixCoord const &extent) const { + return cute::Layout, cute::Stride > >{ + {extent[0], extent[1]}, + {stride(0), cute::Int<1>{}} + }; + } }; /// Mapping function for column-major matrices. @@ -236,6 +247,15 @@ class ColumnMajor { LongIndex capacity(MatrixCoord const &extent) const { return LongIndex(extent.column()) * LongIndex(stride_[0]); } + + CUTLASS_HOST_DEVICE + cute::Layout, cute::Stride< cute::Int<1>, int64_t> > + to_cute_layout(MatrixCoord const &extent) const { + return cute::Layout, cute::Stride, int64_t> >{ + {extent[0], extent[1]}, + {cute::Int<1>{}, stride(0)} + }; + } }; /// Mapping function for interleaved matrices. Matrix is structured diff --git a/include/cutlass/pipeline.hpp b/include/cutlass/pipeline.hpp new file mode 100644 index 0000000000..67538aea17 --- /dev/null +++ b/include/cutlass/pipeline.hpp @@ -0,0 +1,529 @@ +/*************************************************************************************************** + * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/numeric/integral_constant.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +using namespace arch; +using namespace cute; + +// Circular Buffer Index + Associated Phase +// Assumes only one operation possible - i.e., ++ +template +struct PipelineState { + + static constexpr uint32_t Stages = Stages_; + +private: + int index_ = 0; + uint32_t phase_ = 0; + +public: + CUTLASS_DEVICE + PipelineState(): index_{}, phase_{} {} + + CUTLASS_DEVICE + PipelineState(int index, uint32_t phase) + : index_(index) + , phase_(phase){} + + CUTLASS_DEVICE + int index() const { + return index_; + } + + CUTLASS_DEVICE + uint32_t phase() const { + return phase_; + } + + CUTLASS_DEVICE + void operator++() { + ++index_; + if (index_ == Stages) { + index_ = 0; + phase_ ^= 1; + } + } + + CUTLASS_DEVICE + PipelineState& operator=(const PipelineState& other) { + index_ = other.index(); + phase_ = other.phase(); + return *this; + } + + CUTLASS_DEVICE + PipelineState advance(uint32_t num_iterations) { + // Number of iterations cross over the stage boundary => flipped phase + if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { + phase_ ^= 1; + } + // How many times number of iterations cross over the stage boundary and + // end up on a odd number => flipped phase + if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { + phase_ ^= 1; + } + index_ = (index_ + num_iterations) % Stages; + return *this; + } + + CUTLASS_DEVICE + static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { + return start_state.advance(num_iterations); + } +}; + +template +CUTLASS_DEVICE +PipelineState make_producer_start_state() +{ + // Producer starts with an opposite phase as the buffer are initially empty + constexpr int InitialProducerStage = 0; + constexpr uint32_t InitialProducerPhase = 1; + return {InitialProducerStage, InitialProducerPhase}; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA (producer) Async Pipeline class +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Assumptions : Constructor is Visible Cluster-wide (as it needs a Cluster-Sync) +// We have exactly one thread elected in the Producer as the "leader" +// Currently, it is optional to elect a leader for the Consumers +template +class PipelineTmaAsync { +public : + using ClusterShape = ClusterShape_; + using FullBarrier = ClusterTransactionBarrier; + using EmptyBarrier = ClusterBarrier; + using ValueType = FullBarrier::ValueType; + static constexpr uint32_t Stages = Stages_; + + struct SharedStorage { + FullBarrier full_barrier_[Stages]; + EmptyBarrier empty_barrier_[Stages]; + }; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + uint32_t transaction_bytes = 0; + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t is_leader = 0; + uint32_t num_consumers = 0; + }; + +private : + // + // Data Members + // + uint32_t dst_blockid_ = 0; + uint32_t is_signalling_thread_ = 0; + FullBarrier *full_barrier_ptr_ = nullptr; + EmptyBarrier *empty_barrier_ptr_ = nullptr; + Params params_; + + // + // Methods + // + +public: + // Constructor + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params) + : params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + auto cluster_shape = ClusterShape{}; + + if (warp_idx == 0 && lane_predicate == 1) { + // Barrier FULL init + for (int i = 0; i < Stages; ++i) { + full_barrier_ptr_[i].init(1); + } + + // Barrier EMPTY init + uint32_t const num_consumers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; + for (int i = 0; i < Stages; ++i) { + empty_barrier_ptr_[i].init(num_consumers); + } + } + + // Logic to optimally schedule Empty Arrives + // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) + dim3 block_id = block_id_in_cluster(); + auto cluster_size = cute::size(cluster_shape); + static constexpr int MaxClusterSize = 16; + static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); + + // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) + if (params_.num_consumers == 128) { + int thread_idx = threadIdx.x % 128; + is_signalling_thread_ = (thread_idx % (128 / MaxClusterSize)) == 0; + auto layout = cute::composition(Swizzle<2,0,-2>{}, + Layout,Stride<_4, _1>>{}); + uint32_t thread_row = warp_idx % 4; + uint32_t thread_col = (thread_idx / 8) % 4; + dst_blockid_ = layout(thread_row, thread_col); + } + else if (params_.num_consumers == 32){ + int thread_idx = threadIdx.x % 32; + is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; + auto layout = Layout,Stride<_4, _1>>{}; + uint32_t thread_row = thread_idx / 8; + uint32_t thread_col = (thread_idx % 8) / 2; + dst_blockid_ = layout(thread_row, thread_col); + } + else { + is_signalling_thread_ = 0; + } + + // STEP 2: Find if this dst block-id needs an arrival for this problem + is_signalling_thread_ &= dst_blockid_ < cluster_size; + is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); + + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { + // 1. Wait for empty barrier to be ready + // 2. Set the transaction bytes set to occur on the Full barrier + uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); + if ((!done) && (!skip_wait)){ + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); + } + + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state) { + producer_acquire(state.index(), state.phase()); + } + + // NOP for TMA based mainloop + CUTLASS_DEVICE + void producer_commit(uint32_t stage, uint32_t bytes) { + // Below code is used only for unit-testing (in the absennce of TMA commit) + #if CUTLASS_UNIT_TEST_PIPELINE + if (params_.is_leader) { + // STEP 1 : Commit to self + full_barrier_ptr_[stage].commit(bytes); + + // STEP 2 : Commit to other blocks in our cluster + auto cluster_shape = ClusterShape{}; + Layout block_layout_in_cluster = make_layout(cluster_shape); + dim3 local_block_id = cute::block_id_in_cluster(); + + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { + uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); + full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); + } + + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { + uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); + full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); + } + } + #endif + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state, uint32_t bytes) { + producer_commit(state.index(), bytes); + } + + + // Wait for producer to commit transactions (done by TMA) + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase) { + uint32_t done = full_barrier_ptr_[stage].test_wait(phase); + if (!done){ + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state) { + consumer_wait(state.index(), state.phase()); + } + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + + CUTLASS_DEVICE + ValueType* producer_get_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } + + CUTLASS_DEVICE + bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { + return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || + (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Simple producer-consumer async Pipeline class +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// *Count Signifies the number of producers / consumers who will announce their completion + +template +class PipelineAsync { +public : + using FullBarrier = ClusterBarrier; + using EmptyBarrier = ClusterBarrier; + using ProducerBarrierType = FullBarrier::ValueType; + static constexpr uint32_t Stages = Stages_; + + struct SharedStorage { + FullBarrier full_barrier_[Stages]; + EmptyBarrier empty_barrier_[Stages]; + }; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t producer_arv_count = 1; + uint32_t consumer_arv_count = 1; + uint32_t dst_blockid = cute::block_rank_in_cluster(); + }; + +private: + // + // Data Members + // + Params params_; + FullBarrier *full_barrier_ptr_; + EmptyBarrier *empty_barrier_ptr_; + +public: + + // Default assumption when only storage is passed is : + // => single producer, single consumer & they are in the same block (within the Cluster) + CUTLASS_DEVICE + PipelineAsync(SharedStorage& storage) + : PipelineAsync(storage, {}) {} + + CUTLASS_DEVICE + PipelineAsync( + SharedStorage& storage, + Params const& params) : + params_(params), + full_barrier_ptr_(&storage.full_barrier_[0]), + empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by thread 0 of the block + if (warp_idx == 0 && lane_predicate == 1) { + for (int i = 0; i < Stages; ++i) { + full_barrier_ptr_[i].init(params.producer_arv_count); + empty_barrier_ptr_[i].init(params.consumer_arv_count); + } + } + + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { + uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); + if ((!done) && (!skip_wait)){ + empty_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state) { + producer_acquire(state.index(), state.phase()); + } + + CUTLASS_DEVICE + void producer_commit(uint32_t stage) { + full_barrier_ptr_[stage].arrive(); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index()); + } + + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase) { + uint32_t done = full_barrier_ptr_[stage].test_wait(phase); + if (!done){ + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state) { + consumer_wait(state.index(), state.phase()); + } + + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + + CUTLASS_DEVICE + ProducerBarrierType* get_producer_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } +}; + + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Barrier to ensure an Ordered Sequence between +// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages +// i.e., for all i < j - only after id "i" arrives at a particular stage "m" +// will the wait() for id "j" succeed for the same stage +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class OrderedSequenceBarrier { +public : + using Barrier = ClusterBarrier; + + struct SharedStorage { + Barrier barrier_[SequenceDepth][SequenceLength]; + }; + + struct Params { + uint32_t group_id; + uint32_t group_size; + }; + +private : + // + // Data Members + // + + // In future this Params object can be replaced easily with a CG object + Params params_; + Barrier *barrier_ptr_; + PipelineState stage_; + + static constexpr int Depth = SequenceDepth; + static constexpr int Length = SequenceLength; + +public: + OrderedSequenceBarrier() = delete; + OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; + OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; + OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; + OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; + ~OrderedSequenceBarrier() = default; + + CUTLASS_DEVICE + OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : + params_(params), + barrier_ptr_(&storage.barrier_[0][0]), + // Group 0 - starts with an opposite phase + stage_({0, params.group_id == 0}) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by the one elected thread of the block + if (warp_idx == 0 && lane_predicate == 1) { + for (int d = 0; d < Depth; ++d) { + for (int l = 0; l < Length; ++l) { + barrier_ptr_[d * Length + l].init(params.group_size); + } + } + } + + cutlass::arch::fence_barrier_init(); + } + + // Wait on a stage to be unlocked + CUTLASS_DEVICE + void wait() { + get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); + } + + // Signal completion of Stage and move to the next stage + // (group_id) signals to (group_id+1) + CUTLASS_DEVICE + void arrive() { + int signalling_id = (params_.group_id + 1) % Length; + get_barrier_for_current_stage(signalling_id).arrive(); + ++stage_; + } + +private: + + CUTLASS_DEVICE + Barrier& get_barrier_for_current_stage(int group_id) { + return barrier_ptr_[stage_.index() * Length + group_id]; + } +}; + +} // end namespace cutlass diff --git a/include/cutlass/quaternion.h b/include/cutlass/quaternion.h index d62d1c6274..1015be4bf0 100644 --- a/include/cutlass/quaternion.h +++ b/include/cutlass/quaternion.h @@ -745,7 +745,6 @@ struct multiply_add, Quaternion, Quaternion> { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 8ed0538c44..c084dd4870 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Templates implementing how threads are mapped to a given tile. + \brief Templates implementing how threads are mapped to a given tile. */ @@ -163,9 +163,9 @@ struct PitchLinearTilePolicyStripminedThreadContiguous using Iterations = layout::PitchLinearShape< Shape::kContiguous / (kThreads * kElementsPerAccess), - Shape::kStrided>; + Shape::kStrided>; - using Delta = layout::PitchLinearShape<1, 1>; + using Delta = layout::PitchLinearShape<1, 1>; CUTLASS_HOST_DEVICE static TensorCoord initial_offset(int thread_id) @@ -183,7 +183,7 @@ struct PitchLinearTilePolicyStripminedThreadStrided { static_assert((Shape::kStrided % Threads == 0), "Strided shape must divide number of threads"); - + using TensorCoord = layout::PitchLinearCoord; static int const kThreads = Threads; @@ -191,16 +191,16 @@ struct PitchLinearTilePolicyStripminedThreadStrided using Iterations = layout::PitchLinearShape< Shape::kContiguous / kElementsPerAccess, - Shape::kStrided / kThreads>; + Shape::kStrided / kThreads>; - using Delta = layout::PitchLinearShape<1, 1>; + using Delta = layout::PitchLinearShape<1, 1>; using ShapeVec = Shape; CUTLASS_HOST_DEVICE static TensorCoord initial_offset(int thread_id) { - + return TensorCoord(0, thread_id * Iterations::kStrided); } }; @@ -334,7 +334,7 @@ struct PitchLinearWarpRakedThreadMap { }; // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = warp_footprint * warp_offset + thread_offset_in_warp; // This is the offset of a thread within a threadblock tile (units of elements) @@ -460,7 +460,7 @@ struct PitchLinearStridedWarpRakedThreadMap { }; // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = warp_footprint * warp_offset + thread_offset_in_warp; // This is the offset of a thread within a threadblock tile (units of elements) @@ -601,7 +601,7 @@ struct TransposePitchLinearThreadMapSimt { static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = + using Iterations = layout::PitchLinearShape; @@ -615,7 +615,7 @@ struct TransposePitchLinearThreadMapSimt { ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) using Delta = - layout::PitchLinearShape; @@ -693,12 +693,12 @@ struct PitchLinearWarpStripedThreadMap { // Divide it into the number of warps, first partitioning the strided dimension then the // contiguous. - static int const kWarpsStrided = - (WarpAccessIterations::kStrided >= kWarpCount + static int const kWarpsStrided = + (WarpAccessIterations::kStrided >= kWarpCount ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); - static int const kWarpsContiguous = - (kWarpCount > WarpAccessIterations::kStrided ? + static int const kWarpsContiguous = + (kWarpCount > WarpAccessIterations::kStrided ? WarpAccessIterations::kContiguous / kWarpsStrided : 1); /// Arrangement of warps within a threadblock-scoped tile @@ -752,7 +752,7 @@ struct PitchLinearWarpStripedThreadMap { }; // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = warp_footprint * warp_offset + thread_offset_in_warp; // This is the offset of a thread within a threadblock tile (units of elements) @@ -776,7 +776,7 @@ struct PitchLinearWarpStripedThreadMap { template < typename Shape_, int Threads, - typename ThreadTileShape + typename ThreadTileShape > struct PitchLinear2DThreadTileStripminedThreadMap; @@ -888,7 +888,7 @@ struct TransposePitchLinearThreadMap2DThreadTile { static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = + using Iterations = layout::PitchLinearShape; @@ -899,7 +899,7 @@ struct TransposePitchLinearThreadMap2DThreadTile { ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) using Delta = - layout::PitchLinearShape; diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index df65623c66..38d5b4d587 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -54,7 +54,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Optionally enable GCC's built-in type -#if (defined(__x86_64) || defined (__aarch64__)) && !defined(__CUDA_ARCH__) && defined(__GNUC__) +#if defined(__x86_64) && !defined(__CUDA_ARCH__) && defined(__GNUC__) #define CUTLASS_UINT128_NATIVE #elif defined(_MSC_VER) && defined(_M_AMD64) && !defined(__CUDA_ARCH__) #define CUTLASS_INT128_ARITHMETIC @@ -71,7 +71,7 @@ namespace cutlass { struct uint128_t { /// Size of one part of the uint's storage in bits - int const kPartSize = sizeof_bits::value; + static constexpr int kPartSize = sizeof_bits::value; struct hilo { uint64_t lo; @@ -158,7 +158,7 @@ struct uint128_t { /// Multiply by unsigned 64b integer yielding 128b integer CUTLASS_HOST_DEVICE uint128_t operator*(uint64_t const &rhs) const { - uint128_t y; + uint128_t y{}; #if defined(CUTLASS_UINT128_NATIVE) y.native = native * rhs; #elif defined(CUTLASS_INT128_ARITHMETIC) diff --git a/media/docs/code_organization.md b/media/docs/code_organization.md index 61ffbafe9c..53ffc84dfe 100644 --- a/media/docs/code_organization.md +++ b/media/docs/code_organization.md @@ -7,6 +7,7 @@ This document describes the layout of the CUTLASS repository. The main components are: * **CUTLASS Template Library** - CUDA Templates for Linear Algebra Subroutines and Solvers (header only) +* **CuTe Template Library** - CUTLASS's core vocabulary layout type and associated algebra (header only) * **CUTLASS Utilities** - Additional templates * **CUTLASS Instance Library** - instantiations of CUTLASS templates covering the design space * **CUTLASS Profiler** - CUTLASS Library, Profiler, and Utilities @@ -29,7 +30,6 @@ CUTLASS Templates are implemented by header files in the following directory str ``` include/ # Top-level include directory. Client applications should target this path. - cutlass/ # CUDA Templates for Linear Algebra Subroutines and Solvers - headers only arch/ # direct exposure of architecture features (including instruction-level GEMMs) @@ -37,10 +37,11 @@ include/ # Top-level include directory. Client applications gemm/ # code specialized for general matrix product computations thread/ # thread-level operators warp/ # warp-level operators + collective/ # 3.x API operators for all threads a tiled mma/copy are built over threadblock/ # CTA-level operators kernel/ # CUDA kernel entry points device/ # launches kernel(s) over a full device - * # scope-agnostic components and basic vocabular type definitions for GEMM + * # scope-agnostic components and basic vocabulary type definitions for GEMM layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory * @@ -51,7 +52,7 @@ include/ # Top-level include directory. Client applications threadblock/ # CTA-level operators kernel/ # CUDA kernel entry points device/ # launches kernel(s) over a full device - * # scope-agnostic components and basic vocabular type definitions + * # scope-agnostic components and basic vocabulary type definitions transform/ # code specialized for layout, type, and domain transformations thread/ # thread-level operators @@ -64,11 +65,27 @@ include/ # Top-level include directory. Client applications util/ # miscellaneous CUTLASS components * * # core vocabulary types and fundamental arithmetic operators + + cute / # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy + algorithm/ # Definitions of core operations such as copy, gemm, and operations on cute::tuples + arch/ # Bare bones PTX wrapper structs for copy and math instructions + atom/ # Meta-information either link to or built from arch/ operators + mma_atom.hpp # cute::Mma_Atom and cute::TiledMma + copy_atom.hpp # cute::Copy_Atom and cute::TiledCopy + *sm*.hpp # Arch specific meta-information for copy and math operations + container/ # Core container types used across CuTe, namely, cute::tuple + numeric/ # CuTe's internal numerics implementation + * # Core library types such as Shape, Stride, Layout, Tensor, and associated operations ``` See [Programming Guidelines](/media/docs/programming_guidelines.md) for further details about conventions and design patterns used throughout CUTLASS. +## CuTe + +CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly packages the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. More documentation +for CuTe can be found in [`/media/docs/cute/`](/media/docs/cute/). + ## Tools The `tools/` directory contains clients of the CUTLASS Template library and includes the following. @@ -181,9 +198,9 @@ examples/ 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes - 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu + 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu activation function - 13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel + 13_fused_two_gemms/ # example demonstrating two GEMMs fused into one kernel ``` ## Media diff --git a/media/docs/cute/00_quickstart.md b/media/docs/cute/00_quickstart.md new file mode 100644 index 0000000000..df7ceadc7e --- /dev/null +++ b/media/docs/cute/00_quickstart.md @@ -0,0 +1,75 @@ +# Getting Started With CuTe + +CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly packages the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. + +The core abstraction of CuTe are the hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. + +## System Requirements + +CuTe shares CUTLASS 3.0's software requirements, +including NVCC with a C++17 host compiler. + +## Knowledge prerequisites + +CuTe is a CUDA C++ library. It requires C++17 +(the revision of the C++ Standard that was released in 2017). + +Throughout this tutorial, we assume intermediate C++ experience. +For example, we assume that readers know +how to read and write templated functions and classes, and +how to use the `auto` keyword to deduce a function's return type. +We will be gentle with C++ and explain some things +that you might already know. + +We also assume intermediate CUDA experience. +For example, readers must know +the difference between device and host code, +and how to launch kernels. + +## Building Tests and Examples + +CuTe's tests and examples build and run as part of CUTLASS's normal build process. +CuTe's unit tests live in the [`test/unit/cute`](../../../test/unit/cute) subdirectory. +Its examples live in the [`examples/cute`](../../../examples/cute) subdirectory. + +## Library Organization + +CuTe is a header-only C++ library, so there is no source code that needs building. Library headers are contained within the top level [`include/cute`](../../../include/cute) directory, with components of the library grouped by directories that represent their semantics. + +| Directory | Contents | +|------------------------|------------------------| +| [`include/cute`](../../../include/cute) | Each header in the top level corresponds to one of the fundamental building blocks of CuTe, such as [`Layout`](../../../include/cute/layout.hpp) or [`Tensor`](../../../include/cute/tensor.hpp). | +| [`include/cute/container`](../../../include/cute/container) | Implementations of STL-like container objects, such as tuple, array, aligned array, and array views. | +| [`include/cute/numeric`](../../../include/cute/numeric) | Templates that handle nonstandard floating-point types, unsigned integers, complex numbers, and integer sequence - like fundamental numeric data types. | +| [`include/cute/algorithm`](../../../include/cute/algorithm) | Implementations of utility algorithms such as copy, fill, and clear that automatically leverage architecture-specific features if available. | +| [`include/cute/arch`](../../../include/cute/arch) | Wrappers for architecture-specific matrix-matrix multiply and copy instructions. | +| [`include/cute/atom`](../../../include/cute/atom) | Meta-information for instructions in `arch` and utilities like partitioning and tiling. + +## Tutorial + +This directory contains a CuTe tutorial in Markdown format. +The file +[`0x_gemm_tutorial.md`](./0x_gemm_tutorial.md) +explains how to implement dense matrix-matrix multiply using CuTe components. +It gives a broad overview of CuTe and thus would be a good place to start. + +Other files in this directory discuss specific parts of CuTe. + +* [`01_layout.md`](./01_layout.md) describes `Layout`, CuTe's core abstraction. + +* [`02_layout_operations.md`](./02_layout_operations.md) describes more advanced `Layout` operations and the CuTe layout algebra. + +* [`03_tensor.md`](./03_tensor.md) describes `Tensor`, + a multidimensional array abstraction which composes `Layout` + with an array of data. + +* [`04_algorithms.md`](./04_algorithms.md) summarizes CuTe's + generic algorithms that operate on `Tensor`s. + +* [`0t_mma_atom.md`](./0t_mma_atom.md) demonstrates CuTe's meta-information and interface to our GPUs' + architecture-specific Matrix Multiply-Accumulate (MMA) instructions. + +* [`0x_gemm_tutorial.md`](./0x_gemm_tutorial.md) provides a walkthrough of building a GEMM from scratch using CuTe. + +* [`0y_predication.md`](./0y_predication.md) explains what to do + if a tiling doesn't fit evenly into a matrix. diff --git a/media/docs/cute/01_layout.md b/media/docs/cute/01_layout.md new file mode 100644 index 0000000000..882d541ab3 --- /dev/null +++ b/media/docs/cute/01_layout.md @@ -0,0 +1,254 @@ +# CuTe Layouts + +## Layout + +This document describes `Layout`, CuTe's core abstraction. +A `Layout` maps from (a) logical coordinate space(s) +to a physical index space. + +`Layout`s present a common interface to multidimensional array access +that abstracts away the details of how the array's elements are organized in memory. +This lets users write algorithms that access multidimensional arrays generically, +so that layouts can change, without users' code needing to change. + +CuTe also provides an "algebra of `Layout`s." +`Layout`s can be combined and manipulated +to construct more complicated layouts +and to partition them across other layouts. +This can help users do things like partition layouts of data over layouts of threads. + +## Layouts and Tensors + +Any of the `Layout`s discussed in this section can be composed with data -- a pointer or an array -- to create a `Tensor`. The responsibility of the `Layout` is to define valid coordinate space(s) and, therefore, the logical shape of the data and map those into an index space. The index space is precisely the offset that would be used to index into the array of data. + +For details on `Tensor`, please refer to the +[`Tensor` section of the tutorial](./03_tensor.md). + +## Shapes and Strides + +A `Layout` is a pair of `Shape` and `Stride`. +Both `Shape` and `Stride` are `IntTuple` types. + +### IntTuple + +An `IntTuple` is an integer or a tuple of `IntTuple`s. +This means that `IntTuple`s can be arbitrarily nested. +Operations defined on `IntTuple`s include the following. + +* `get(IntTuple)`: The `I`th element of the `IntTuple`. Note that `get<0>` is defined for integer `IntTuples`. + +* `rank(IntTuple)`: The number of elements in an `IntTuple`. An int has rank 1, a tuple has rank `tuple_size`. + +* `depth(IntTuple)`: The number of hierarchical `IntTuple`s. An int has depth 0, a tuple has depth 1, a tuple that contains a tuple has depth 2, etc. + +* `size(IntTuple)`: The product of all elements of the IntTuple. + +We write `IntTuple`s with parenthesis to denote the hierarchy. E.g. `6`, `(2)`, `(4,3)`, `(3,(6,2),8)` are all `IntTuple`s. + +## Layout + +A `Layout` is then a pair of `IntTuple`s. The first defines the abstract *shape* of the layout and the second defines the *strides*, which map from coordinates within the shape to the index space. + +As a pair of `IntTuple`s, we can define many similar operations on `Layout`s including + +* `get(Layout)`: The `I`th sub-layout of the `Layout`. + +* `rank(Layout)`: The number of modes in a `Layout`. + +* `depth(Layout)`: The number of hierarchical `Layout`s. An int has depth 0, a tuple has depth 1, a tuple that contains a tuple has depth 2, etc. + +* `shape(Layout)`: The shape of the `Layout`. + +* `stride(Layout)`: The stride of the `Layout`. + +* `size(Layout)`: The logical extent of the `Layout`. Equivalent to `size(shape(Layout))`. + +### Hierarchical access functions + +`IntTuple`s and thus `Layout`s can be arbitrarily nested. +For convenience, we define versions of some of the above functions +that take a sequence of integers, instead of just one integer. +This makes it possible to access elements +inside of nested `IntTuple` or `Layout`. +For example, we permit `get(x)`, where `I...` here +and throughout this section is a "C++ parameter pack" +that denotes zero or more (integer) template arguments. +That is, `get(x)` is equivalent to +`get(` $\dots$ `(get(get(x)))` $\dots$ `))`, +where the ellipses are pseudocode and not actual C++ syntax. +These hierarchical access functions include the following. + +* `rank(x) := rank(get(x))`. The rank of the `I...`th element of `x`. + +* `depth(x) := depth(get(x))`. The depth of the `I...`th element of `x`. + +* `size(x) := size(get(x))`. The size of the `I...`th element of `x`. + +### Vector examples + +Then, we can define a vector as any `Shape` and `Stride` pair with `rank == 1`. +For example, the `Layout` + +``` +Shape: (8) +Stride: (1) +``` + +defines a contiguous 8-element vector. +Similarly, with a stride of `(2)`, +the interpretation is that the eight elements +are stored at positions 0, 2, 4, $\dots$. + +By the above definition, we *also* interpret + +``` +Shape: ((4,2)) +Stride: ((1,4)) +``` + +as a vector, since its shape is rank 1. The inner shape describes a 4x2 layout of data in column-major order, but the extra pair of parenthesis suggest we can interpret those two modes as a single 1-D 8-element vector instead. Due to the strides, the elements are also contiguous. + +### Matrix examples + +Generalizing, we define a matrix as any `Shape` and `Stride` pair with rank 2. For example, + +``` +Shape: (4,2) +Stride: (1,4) + 0 4 + 1 5 + 2 6 + 3 7 +``` + +is a 4x2 column-major matrix, and + +``` +Shape: (4,2) +Stride: (2,1) + 0 1 + 2 3 + 4 5 + 6 7 +``` + +is a 4x2 row-major matrix. + +Each of the modes of the matrix can also be split into *multi-indices* like the vector example. +This lets us express more layouts beyond just row major and column major. For example, + +``` +Shape: ((2,2),2) +Stride: ((4,1),2) + 0 2 + 4 6 + 1 3 + 5 7 +``` + +is also logically 4x2, with a stride of 2 across the rows but a multi-stride down the columns. +Since this layout is logically 4x2, +like the column-major and row-major examples above, +we can _still_ use 2-D coordinates to index into it. + +## Constructing a `Layout` + +A `Layout` can be constructed in many different ways. +It can include any combination of compile-time (static) integers +or run-time (dynamic) integers. + +```c++ +auto layout_8s = make_layout(Int<8>{}); +auto layout_8d = make_layout(8); + +auto layout_2sx4s = make_layout(make_shape(Int<2>{},Int<4>{})); +auto layout_2sx4d = make_layout(make_shape(Int<2>{},4)); + +auto layout_2x4 = make_layout(make_shape (2, make_shape (2,2)), + make_stride(4, make_stride(1,2))); +``` + +## Using a `Layout` + +The fundamental use of a `Layout` is to map between logical coordinate space(s) and index space. For example, to print an arbitrary rank-2 layout, we can write the function + +```c++ +template +void print2D(Layout const& layout) +{ + for (int m = 0; m < size<0>(layout); ++m) { + for (int n = 0; n < size<1>(layout); ++n) { + printf("%3d ", layout(m,n)); + } + printf("\n"); + } +} +``` + +which produces the following output for the above examples. + +``` +> print2D(layout_2sx4s) + 0 2 4 6 + 1 3 5 7 +> print2D(layout_2sx4d) + 0 2 4 6 + 1 3 5 7 +> print2D(layout_2x4) + 0 2 1 3 + 4 6 5 7 +``` + +The multi-indices within the `layout_4x4` example are handled as expected and interpreted as a rank-2 layout. + +Note that for `layout_1x4`, we're using a 1-D coordinate for a 2-D multi-index in the second mode. In fact, we can generalize this and treat all of the above layouts as 1-D layouts. For instance, the following `print1D` function + +```c++ +template +void print1D(Layout const& layout) +{ + for (int i = 0; i < size(layout); ++i) { + printf("%3d ", layout(i)); + } +} +``` + +produces the following output for the above examples. + +``` +> print1D(layout_8s) + 0 1 2 3 4 5 6 7 +> print1D(layout_8d) + 0 1 2 3 4 5 6 7 +> print1D(layout_2sx4s) + 0 1 2 3 4 5 6 7 +> print1D(layout_2sx4d) + 0 1 2 3 4 5 6 7 +> print1D(layout_2x4) + 0 4 2 6 1 5 3 7 +``` + +This shows explicitly that all of the layouts are simply folded views of an 8-element array. + +## Summary + +* The `Shape` of a `Layout` defines its coordinate space(s). + + * Every `Layout` has a 1-D coordinate space. + This can be used to iterate in a "generalized-column-major" order. + + * Every `Layout` has a R-D coordinate space, + where R is the rank of the layout. + These spaces are ordered _colexicographically_ + (reading right to left, instead of "lexicographically," + which reads left to right). + The enumeration of that order + corresponds to the 1-D coordinates above. + + * Every `Layout` has an h-D coordinate space where h is "hierarchical." These are ordered colexicographically and the enumeration of that order corresponds to the 1-D coordinates above. An h-D coordinate is congruent to the `Shape` so that each element of the coordinate has a corresponding element of the `Shape`. + +* The `Stride` of a `Layout` maps coordinates to indices. + + * In general, this could be any function from 1-D coordinates (integers) to indices (integers). + + * In `CuTe` we use an inner product of the h-D coordinates with the `Stride` elements. diff --git a/media/docs/cute/02_layout_operations.md b/media/docs/cute/02_layout_operations.md new file mode 100644 index 0000000000..f9c9734a79 --- /dev/null +++ b/media/docs/cute/02_layout_operations.md @@ -0,0 +1,710 @@ +# CuTe Layout Operations + +CuTe provides an "algebra of `Layout`s." +`Layout`s can be combined and manipulated +to construct more complicated `Layout`s. +This includes tiling and partitioning `Layout`s across other `Layout`s. +In this section, we explain some of these core operations in detail. + +## How do I print CuTe objects on host or device? + +CuTe comes with different ways to print CuTe objects. +You can print human-readable text, +or you can print LaTeX commands for generating +a beautifully formatted and colored table +describing the CuTe object. +Both of these can be helpful for reasoning about or debugging +layouts, copy atoms, or matrix multiply atoms +(don't worry, we'll explain all of these things in this tutorial). + +CuTe's print functions work on either host or device. +Note that on device, printing is expensive. +Even just leaving print code in place on device, +even if it is never called +(e.g., printing in an `if` branch that is not taken at run time), +may generate slower code. +Thus, be sure to remove code that prints on device after debugging. + +The following code examples assume that you have a +`using namespace cute;` statement in scope. + +### Printing human-readable text + +The `cute::print` function has overloads for almost all CuTe types, including Pointers, Layout, Shape, Stride, and Tensors. When in doubt, try calling `print` on it. You might also only want to print on thread 0 of each thread block, or block 0 of the grid. The `thread0()` function returns true only for global thread 0 of the kernel. A typical idiom for printing CuTe objects to print only on thread 0 of block 0. + +```c++ +if (thread0()) { + print(some_cute_object); +} +``` + +Some algorithms do different things on different threads or blocks, +so you might sometimes need to print on threads or blocks other than zero. +The header file +[`cute/util/debug.hpp`](../../../include/cute/util/debug.hpp), +among other utilities, +includes the function `bool thread(int tid, int bid)` +that returns `true` if running on thread `tid` and block `bid`. + +Some CuTe types have special printing functions that use a different output format. +For example, `print_layout` can display a rank-2 layout in a table +(using plain text formatting). +It has an overload taking a rank-2 matrix layout and a thread layout, +that displays a table with the mapping between threads and values. + +Some CuTe types might not have overloads for `print`, +but there are other ways to print their contents. +For example, copy atoms and mma atoms +(see elsewhere in this tutorial) +have a `print_all()` member function. + +### Printing LaTeX output + +The `cute::print_latex` function works like `cute::print`, +but prints LaTeX commands that you can use +to generate a nicely formatted and colored table. + +## Fundamental types + +### Layout and its components + +This directory includes +[an overview of CuTe's fundamental types for describing layouts](./01_layout.md). + +#### Tuple + +CuTe starts with a Tuple, which is a finite ordered list of zero or more elements. +In C++, we identify a Tuple with the +[`cute::tuple` class](../../../include/cute/container/tuple.hpp). +`cute::tuple` behaves like `std::tuple`, but it works on device or host, +and it imposes restrictions on its template arguments for performance and simplicity. + +#### IntTuple + +CuTe then defines an IntTuple as either an integer, or a Tuple of IntTuple. +This recursive definition lets us build arbitrarily nested layouts. +In C++, we identify an IntTuple with [`IntTuple`](../../../include/cute/int_tuple.hpp), +which is just an alias of `cute::tuple`. +Any of the following are thus valid template arguments of IntTuple. + +1. "Run-time integers" (or "static integers") + are just ordinary integral types like `int` or `size_t`. + +2. "Compile-time integers" include `std::integral_constant` + or subclasses of it that CuTe defines, + such as `Int` (see below). + These types all have in common + that the value is encoded in the type itself + (as a public `static constexpr value` member). + CuTe defines aliases `_1`, `_2`, `_3` etc. + to the types `Int<1>`, `Int<2>`, `Int<3>` etc. + +3. `IntTuple` with any valid template arguments. + +CuTe reuses IntTuple for many different things, +including Shape, Stride, Step, and Coord +(see [`include/cute/layout.hpp`](../../../include/cute/layout.hpp)). +In C++, Shape, Stride, Step, and Coord are all aliases for IntTuple. + +### Layout + +A Layout is a tuple of (Shape, Stride). +Semantically, it implements a mapping from +a "logical" Shape-shaped (multidimensional) index, +to a "physical" 1-D index into an array. +Here is an example of a 2 x 3 array with static strides (3, 1). + +```c++ +Layout layout = make_layout(make_shape (_2{}, _3{}), + make_stride(_3{}, _1{})); +print_layout(layout); +for (int i = 0; i < size(layout); ++i) { + print(layout(i)); + print(", "); +} +print("\n"); +print(layout(1, 1)); +print("\n"); +``` + +This code produces the following text output. + +```text +(_2,_3):(_3,_1) + 0 1 2 + +---+---+---+ + 0 | 0 | 1 | 2 | + +---+---+---+ + 1 | 3 | 4 | 5 | + +---+---+---+ +0, 3, 1, 4, 2, 5, +4 +``` + +`print(layout(1, 1))` prints the mapping of +the logical 2-D coordinate (0,1) to 1-D index, which is 4. +You can see that from the table, +which shows the left logical index as the "row," +and the right logical index as the "column." + +### Underscore (`_`) + +An Underscore is a special type used for array slices. The underscore punctuation `_` is a constant instance of Underscore. It acts like `:` (the colon punctuation) in Python or Fortran array slices. See [`include/cute/underscore.hpp`](../../../include/cute/underscore.hpp). + +### Tile + +"A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores." +See [`include/cute/tile.hpp`](../../../include/cute/tile.hpp). + +The algebraic layout operations discussed below are defined on `Layout`s, but `Tile` allows these operations to recurse and to be applied to sublayouts or particular modes of a given Layout. These are referred to as by-mode operations. + +See the section on "Logical Divide" to see an example of using `Tile` to extract portions of a row-mode and portions of a column-mode independently. + +## Layout definitions and operations + +### Layouts are functions from integers (logical 1-D coordinate) to integers (1-D index) + +The `for` loop in the above print example shows how CuTe identifies 1-D coordinates with a column-major layout of logical 2-D coordinates. Iterating from `i = 0` to `size(layout)` (which is 6), and indexing into our layout with the single integer coordinate `i`, traverses the layout in column-major fashion, even though this is a row-major layout. You can see this from the output of the `for` loop (0, 3, 1, 4, 2, 5). CuTe calls this index `i` a "1-D coordinate," versus the "natural coordinate," which would be the logical 2-D coordinate. + +If you're familiar with the C++23 feature `mdspan`, +this is an important difference between +`mdspan` layout mappings and CuTe `Layout`s. +`mdspan` layout mappings are *one way*: +they always take a multidimensional logical coordinate, +and they return an integer offset. +Depending on the strides, +the offset may skip over elements of the physical 1-D array. +Thus, `mdspan`'s offset does NOT mean the same thing as +the 1-D logical coordinate `i` in the `for` loop above. +You can iterate correctly over any CuTe `Layout` +by using the 1-D logical coordinate. +`mdspan` doesn't have an idea of a 1-D logical coordinate. + +### Rank, depth, size, cosize + +*Rank*: the tuple size of the layout's shape. + +*Depth*: the depth of the layout's shape. A single integer has depth 0. A tuple has depth 1 + the max depth of its components. + +*Size*: Size of the shape; size of the domain of the function. This is the product of all extents in the layout's shape. + +*Cosize*: Size of the function's codomain (not necessarily the range); for a layout A, A(size(A) - 1) + 1. (Here, we use size(A) - 1 as a 1-D logical coordinate input.) + +### Layout compatibility + +We say that layouts A and B are *compatible* if their shapes are compatible. Shape A is compatible with shape B if any natural coordinate of A is also a valid coordinate for B. + +### Flatten + +The `flatten` operation "un-nests" a potentially nested Layout. For example, + +```c++ +Layout layout = Layout, _1>, + Stride, _0>>{}; +Layout flat_layout = flatten(layout); +``` + +results in `flat_layout` having the following type + +```text +Layout, Stride<_3, _1, _0>> +``` + +and + +```c++ +Layout layout = Layout>, + Stride<_4, Stride<_1, _16>>>{}; +Layout flat_layout = flatten(layout); +``` + +results in `flat_layout` having the following type + +```text +Layout, Stride<_4, _1, _16>> +``` + +Hierarchical Layouts and flattening let us reinterpret tensors in place as matrices, matrices as vectors, vectors as matrices, etc. This lets us implement arbitrary tensor contractions as batched matrix multiply, by combining the contraction modes into a single mode, and combining the A, B, C, and "batch" modes as needed to reach the desired form. + +### Coalesce + +The `coalesce` operation first flattens the layout, then combines all the modes that are possible to combine, starting with mode 0 (the leftmost mode) and moving right. If all the modes can be combined, then this results in a 1-D layout expressing what array elements the original layout accesses. + +For example, + +```text +layout: (_2,(_1,_6)):(_1,(_6,_2)) +coalesce(layout): _12:_1 +``` + +What does it mean to "combine" modes? In the above example, the flattened layout is (2, 1, 6) : (1, 6, 2). + +1. If we look at the leftmost two modes, this is just a vector of length 2 and stride 1. The middle mode has extent 1, so the corresponding stride 6 would not be observed anyway. This leaves us with (2, 6) : (1, 2). + +2. The intermediate result (2, 6) : (1, 2) is just a 2 x 6 column-major matrix, which can be coalesced into a vector of length 12 and stride 1. + +More formally, "combining all the modes" means a left fold, where the binary operation that combines two modes has three cases. + +1. If the leftmost layout is s1:d1, and the next layout is 1:d0, then combine into s1:d1. This generalizes Step 1 above. If a mode has extent 1, we can't observe its stride, so we can skip the mode. + +2. If the leftmost layout is 1:d1, and the next layout is s0:d0, then combine into s0:d0. Again, if a mode has extent 1, we can't observe its stride, so we can skip the mode. + +3. If the leftmost layout is s1:d1, and the next layout is s0 : s1*d1, then combine into s0 * s1 : d1. This generalizes Step 2 above. One can call this "noticing a column-major layout sequence." + +That's it! For example, the result of coalescing the row-major layout (2, 2) : (2, 1) is (2, 2) : (2, 1), the same layout, because none of the above three cases applies. + +### Complement + +#### Definition + +The complement B of a layout A with respect to an integer M satisfies the following properties. + +1. $A$ and $B$ are *disjoint*: $A(x) \neq B(x)$ for all $x \neq 0$ in the domain of $A$. + +2. B is *ordered*: $`B(x-1) < B(x)`$ for all $x$ in $\{0, 1, \dots, size(B) - 1\}$. + +3. B is *bounded* by M: $size(B) \geq M / size(A)$, and $cosize(B) \leq floor(M / cosize(A)) * cosize(A)$. + +Regarding disjointness: we need to specify $x \neq 0$ because CuTe layouts are linear. That is, if the domain is nonempty, the range always contains zero. + +Regarding the ordered property: CuTe layouts are hierarchically strided, so this implies that if size(B) is nonzero, then the strides of B are all positive. + +#### Examples + +complement(4:1, 24) is 6:4. + +1. The result is disjoint of 4:1, so it must have a stride of at least 4 (since it includes 0, but must skip over 1, 2, 3). + +2. The size of the result is $\geq 24 / 4 = 6$. (This plus Step (1) means that the cosize is at least 24.) + +3. The cosize of the result is $\leq (24 / 4) * 4 = 24$. (This plus Step (2) means that the cosize is exactly 24.) + +4. The only (1-D) layout with size 6 and cosize 24 is 6:4. + +complement(6:4, 24) is 4:1. + +1. 4:1 is disjoint of 6:4, but so is s:d + for any s > 0 and d > 20. + +2. The size of the result is $\geq 24 / 6 = 4$. + +3. The cosize of the result is $\leq (24 / 21) * 21 = 21$. + +4. The stride cannot be greater than 20 + (else (2) would contradict (3)), + so it must be less than 4. + +5. This leaves 4:1 by elimination. + +### Composition + +Layouts are functions, so composition of layouts is just composition of functions. The composition $A \circ B$ means "apply the layout B first, then treat the result as a 1-D logical coordinate input to the layout A, and apply A to it." Very often, this composition can be represented as another Layout. + +#### Rules for computing composition + +Both humans and CuTe compute composition using the following rules. + +1. $A \circ B$ has a shape that is compatible with B. In function composition, the rightmost function defines the domain. For `Layout`s this means that any valid coordinate for $B$ can also be used as a coordinate for $A \circ B$. + +2. Concatenation: A layout can be expressed as the concatenation of its sublayouts. We denote concatenation with parentheses: $B = (B_0,B_1,...)$. The CuTe function `make_layout`, when given zero or more `Layout`s, concatenates them. + +3. Composition is (left-)distributive with concatenation: $A \circ B = A \circ (B0, B1, ...) = (A \circ B0, A \circ B1, ...)$. + +4. "Base case": For layouts $A = a : b$ and $B = c : d$ with integral shape and stride, $A \circ B = R = c : (b * d)$. + +5. By-mode composition: Let $\langle B, C \rangle$ (angle brackets, not parentheses) + denote a tuple of two layouts B and C, not their concatenation. Let A = (A0, A1). + Then, $A \circ \langle B, C \rangle = (A0, A1) \circ \langle B, C \rangle = (A0 \circ B, A1 \circ C)$. + This allows the application of composition independently to sublayouts of $A$. + +#### Examples: Reshape a vector into a matrix + +This section gives two composition examples. Both start with a vector with layout $20:2$ (that is, the vector has 20 elements, and the stride between each is 2). They compose this vector with a 4 x 5 matrix layout. This effectively "reshapes" the vector in place into a matrix. + +##### Example 1 + +$20:2 \circ (4,5) : (1,4)$. + +This describes interpreting the vector $20:2$ +as a 4 x 5 column-major matrix. + +The resulting layout has shape $(4,5)$, +because in function composition, +the rightmost function defines the domain. +What are the strides? + +1. A layout can be expressed as the concatenation of its sublayouts, + so $(4,5) : (1,4)$ is $(4:1, 5:4)$. + +2. Composition is distributive, so + $20:2 \circ (4:1, 5:4)$ is $(20:2 \circ 4:1, 20:2 \circ 5:4)$. + +3. $20:2 \circ 4:1$ has shape 4 (rightmost function defines the domain) + and stride $2 = 2 \cdot 1$. + +4. $20:2 \circ 5:4$ has shape 5 and stride $8 = 2 \cdot 4$. + +5. Result: (4:2, 5:8), which by concatenation is (4,5) : (2,8). + +#### Example 2 + +$20:2 \circ (4,5) : (5,1)$. + +This describes interpreting the vector 20:2 +as a 4 x 5 row-major matrix. + +The resulting layout has shape $(4,5)$, just as before. What are the strides? + +1. By deconcatenation, $(4,5) : (5,1)$ is $(4:5, 5:1)$. + +2. Composition is distributive, so $20:2 \circ (4:5, 5:1)$ is $(20:2 \circ 4:5, 20:2 \circ 5:1)$. + +3. $20:2 \circ 4:5$ has shape $4$ and stride $10 = 2 \cdot 5$. + +4. $20:2 \circ 5:1$ has shape $5$ and stride $2 = 2 \cdot 1$. + +5. Result: (4:10, 5:2), which by concatenation is (4,5) : (10,2). + +### Product + +CuTe includes four different kinds of layout products. + +1. `logical_product` + +2. `blocked_product` + +3. `raked_product` + +4. `tiled_product` + +`logical_product(A, B)` results in a layout where each element of layout B +has been replaced by a "copy" of layout A. +The other three products offer variations of this idea. + +#### Example: Tiled matrix + +Suppose that I want to make a matrix consisting of 3 x 4 tiles +in a row-major arrangement, +where each tile is a 2 x 2 column-major matrix. + +The Layout of each tile (tile) has Shape (2,2) and Stride (1,2). + +The Layout of the "matrix of tiles" (`matrix_of_tiles`) +has Shape (3,4) and Stride (4,1). + +##### Blocked product: the intuitive tiling + +If I were to deduce by hand what the layout of the tiled matrix should be, +it would look like this. + +| | (0,0) | (1,0) | (0,1) | (1,1) | (0,2) | (1,2) | (0,3) | (1,3) | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| (0,0) | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | +| (1,0) | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | +| (0,1) | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 | +| (1,1) | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 | +| (0,2) | 32 | 34 | 36 | 38 | 40 | 42 | 44 | 46 | +| (1,2) | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | + +The row and column labels use the equivalence of 1-D logical coordinates and 2-D column-major coordinates. The left index in each pair is the row resp. column coordinate of the tile, while the right index in each pair is the row resp. column coordinate of the matrix-of-tiles. The resulting layout has Shape ((2, 3), (2, 4)), and Stride ((1, 16), (2, 4)), and the second mode can be coalesced. The Shape ((2, 3), (2, 4)) is hierarchical, but it is still rank-2 and can be drawn in 2D as above. Note how the row mode of the tile remains part of the row mode of the product, and the column mode of the tile remains a column mode of the product. + +The above layout is what `blocked_product(tile, matrix_of_tiles)` produces. +A critical use case for blocked product is "tiling" an "atom" +(some tile that relates to a hardware feature) over a matrix. + +```c++ +Layout tile = Layout, + Stride<_1,_2>>{}; +Layout matrix_of_tiles = Layout, + Stride<_4,_1>>{}; + +print_layout(blocked_product(tile, matrix_of_tiles)); +``` + +##### Logical product + +The logical product `logical_product(tile, matrix_of_tiles)` +results in Shape ((2, 2), (3, 4)) and Stride ((1, 2), (16, 4)). + +| | (0,0) | (1,0) | (2,0) | (0,1) | (1,1) | (2,1) | (0,2) | (1,2) | (2,2) | (0,3) | (1,3) | (2,3) | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| (0,0) | 0 | 16 | 32 | 4 | 20 | 36 | 8 | 24 | 40 | 12 | 28 | 44 | +| (1,0) | 1 | 17 | 33 | 5 | 21 | 37 | 9 | 25 | 41 | 13 | 29 | 45 | +| (0,1) | 2 | 18 | 34 | 6 | 22 | 38 | 10 | 26 | 42 | 14 | 30 | 46 | +| (1,1) | 3 | 19 | 35 | 7 | 23 | 39 | 11 | 27 | 43 | 15 | 31 | 47 | + +Note how the tile appears in the leftmost column and is reproduced +in each column in the same order as the matrix-of-tiles. That is, +the tile can be indexed through the first mode of the result and the +matrix-of-tiles can be indexed through the second mode. + +```c++ +Layout tile = Layout, + Stride<_1,_2>>{}; +Layout matrix_of_tiles = Layout, + Stride<_4,_1>>{}; + +print_layout(logical_product(tile, matrix_of_tiles)); +``` + +##### Raked product + +The raked product `raked_product(tile, matrix_of_tiles)` results in +Shape ((3, 2), (4, 2)) and Stride ((16, 1), (4, 2)). + +| | (0,0) | (1,0) | (2,0) | (3,0) | (0,1) | (1,1) | (2,1) | (3,1) | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| (0,0) | 0 | 4 | 8 | 12 | 2 | 6 | 10 | 14 | +| (1,0) | 16 | 20 | 24 | 28 | 18 | 22 | 26 | 30 | +| (2,0) | 32 | 36 | 40 | 44 | 34 | 38 | 42 | 46 | +| (0,1) | 1 | 5 | 9 | 13 | 3 | 7 | 11 | 15 | +| (1,1) | 17 | 21 | 25 | 29 | 19 | 23 | 27 | 31 | +| (2,1) | 33 | 37 | 41 | 45 | 35 | 39 | 43 | 47 | + +The tile is now interleaved or "raked" with the other 3x4 matrix-of-tiles +instead of appearing as blocks. Other references call this is cyclic +distribution. + +This might look familiar if you have ever used ScaLAPACK. +It expresses a 2-D block cyclic distribution of a 6 x 8 matrix +over 4 processes in a 2 x 2 "process grid." See +["The Two-dimensional Block-Cyclic Distribution"](https://netlib.org/scalapack/slug/node75.html#sec2dbcd) +and +["Local Storage Scheme and Block-Cyclic Mapping"](https://netlib.org/scalapack/slug/node76.html#seclocalstorage) +in the ScaLAPACK Users' Guide. + +In general, `logical_product` and these variations can produce any interleaving, +including blocked, cyclic, by-mode blocked/cyclic, and intermediate interleavings +that don't have common names. + +```c++ +Layout tile = Layout, + Stride<_1,_2>>{}; +Layout matrix_of_tiles = Layout, + Stride<_4,_1>>{}; + +print_layout(raked_product(tile, matrix_of_tiles)); +``` + +### Division + +The previous section covered layout products, +that reproduce one layout over another. +This section covers layout *division*. +Functions that divide a layout into components are useful +as a basis for tiling and partitioning layouts. + +For example, consider folding a vector into a matrix. +We could imagine an operation, called `logical_divide`, + +```c++ +Layout vec = Layout<_16,_3>{}; // 16 : 3 +Layout col = Layout< _4,_1>{}; // 4 : 1 +Layout mat = logical_divide(vec, col); // (4,4) : (3,12) +``` + +that "takes" the first 4 elements of the vector into the first mode +and leaves the "rest" in the second mode. This is a column-major matrix +view of the data in `vec`. +What if we want a row-major matrix view? + +```c++ +Layout vec = Layout<_16,_3>{}; // 16 : 3 +Layout col = Layout< _4,_4>{}; // 4 : 4 +Layout mat = logical_divide(vec, col); // (4,4) : (12,3) +``` + +Now, every fourth element of the vector is in the first mode and +the "rest" are in the second mode. +Multidimensional, hierarchical indices let us extend this operation +to any layout that "divides" the vector. + +```c++ +Layout vec = Layout<_16,_3>{}; // 16 : 3 +Layout col = Layout< _4,_2>{}; // 4 : 2 +Layout mat = logical_divide(vec, col); // (4,(2,2)) : (6,(3,24)) +``` + +```c++ +Layout vec = Layout<_16,_3>{}; // 16 : 3 +Layout col = Layout< _4,_2>{}; +Layout col = Layout, + Stride<_4,_1>>{}; // (2,2) : (4,1) +Layout mat = logical_divide(vec, col); // ((2,2),(2,2)) : ((12,3),(6,24)) +``` + +All of the above examples produce a 4x4 matrix +that can be indexed and treated like a normal 4x4 matrix, +but each has a different underlying layout. +Thus, our algorithms can be written using logical coordinates, +without needing to address the detailed indexing that each layout requires. + +CuTe includes 3 different kinds of layout division operations. + +1. `logical_divide` + +2. `zipped_divide` + +3. `tiled_divide` + +We will summarize these in the sections that follow. + +#### Logical divide : the intuitive tiling + +Suppose I have the 6 x 8 matrix from the Raked Product section +and want to "collect" the `tile`, turning the Raked Product into +the Blocked Product. + +To do this, we would like to gather two elements from the column +and leave the rest, then gather two elements from the row and leave the rest. +Thus, we want to apply `logical_divide` independently to the rows and cols +in order to retrieve the appropriate elements. + +In code, we copy the Layout from the result of the Raked Product section, then +specify the elements in the rows and cols we would like to gather. + +```c++ +Layout raked_prod = Layout,Shape <_4,_2>>, + Stride,Stride<_4,_2>>>{}; +Tile subtile = make_tile(Layout<_2,_3>{}, // Gather elements 2 : 3 from mode 0 + Layout<_2,_4>{}); // Gather elements 2 : 4 from mode 1 + +print_layout(logical_divide(raked_prod, subtile)); +``` + +Indeed, this does produce the result from the Blocked Product section. + +| | (0,0) | (1,0) | (0,1) | (1,1) | (0,2) | (1,2) | (0,3) | (1,3) | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| (0,0) | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | +| (1,0) | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | +| (0,1) | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 | +| (1,1) | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 | +| (0,2) | 32 | 34 | 36 | 38 | 40 | 42 | 44 | 46 | +| (1,2) | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | + +Of course, any other rearrangement of the rows and cols is also valid. + +#### Zipped divide + +The `zipped_divide` function applies `logical_divide`, and then gathers the +"subtiles" into a single mode and the "rest" into a single mode. + +For example, if we apply `zipped_divide` instead of `logical_divide` in the example above, + +```c++ +Layout raked_prod = Layout,Shape <_4,_2>>, + Stride,Stride<_4,_2>>>{}; +Tile subtile = make_tile(Layout<_2,_3>{}, // Gather elements 2 : 3 from mode 0 + Layout<_2,_4>{}); // Gather elements 2 : 4 from mode 1 + +print_layout(zipped_divide(raked_prod, subtile)); +``` + +then we get the result + +| | (0,0) | (1,0) | (2,0) | (0,1) | (1,1) | (2,1) | (0,2) | (1,2) | (2,2) | (0,3) | (1,3) | (2,3) | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| (0,0) | 0 | 16 | 32 | 4 | 20 | 36 | 8 | 24 | 40 | 12 | 28 | 44 | +| (1,0) | 1 | 17 | 33 | 5 | 21 | 37 | 9 | 25 | 41 | 13 | 29 | 45 | +| (0,1) | 2 | 18 | 34 | 6 | 22 | 38 | 10 | 26 | 42 | 14 | 30 | 46 | +| (1,1) | 3 | 19 | 35 | 7 | 23 | 39 | 11 | 27 | 43 | 15 | 31 | 47 | + +Note that this is the same layout as the result in the Logical Product section. +That is, the first mode is our original tile (and can be interpreted as a 2x2 matrix itself) +and the second mode is its logical layout within the raked layout. + +##### More Examples of Divide + +For brevity, shapes can be used with `logical_divide` and `tiled_divide` to quickly split and tile modes of a tensor. For example, this C++ code + +```c++ +Layout layout = Layout, + Stride< _1,_128,_0>>{}; +Shape tile_shape = make_shape(_4{},_8{}); +Layout logical_divided_tile = logical_divide(layout, tile_shape); +Layout zipped_divided_tile = zipped_divide(layout, tile_shape); + +print("layout : "); print(layout); print("\n"); +print("tile_shape : "); print(tile_shape); print("\n"); +print("logical_divided_tile : "); print(logical_divided_tile); print("\n"); +print("zipped_divided_tile : "); print(zipped_divided_tile); print("\n\n"); +``` + +produces the following output when we vary `layout`. + +```text +full_layout : (_12,_32,_6):(_1,_128,_0) +tile_shape : (_4,_8) +logical_divided_tile : ((_4,_3),(_8,_4),_6):((_1,_4),(_128,_1024),_0) +zipped_divided_tile : ((_4,_8),(_3,_4,_6)):((_1,_128),(_4,_1024,_0)) + +full_layout : (_12,(_4,_8),_6):(_1,(_32,_512),_0) +tile_shape : (_4,_8) +logical_divided_tile : ((_4,_3),((_4,_2),_4),_6):((_1,_4),((_32,_512),_1024),_0) +zipped_divided_tile : ((_4,(_4,_2)),(_3,_4,_6)):((_1,(_32,_512)),(_4,_1024,_0)) +``` + +This code + +```c++ +Layout layout = make_layout(Shape<_8,_8>{}, + Stride<_8,_1>{}); +Layout tile = make_tile(make_layout(Shape<_4>{}), + make_layout(Shape<_2>{})); +print("layout: "); +print_layout(layout); +print("\n"); +print("tile: "); +print(tile); +print("\n"); +print("logical_divide: "); +print_layout(logical_divide(layout, tile)); +print("zipped_divide: "); +print_layout(zipped_divide(layout, tile)); +``` + +results in the following layouts. + +

+ logical_divide-and-zipped_divide +

+ +This code + +```c++ +Layout layout = make_layout(Shape<_8,_8>{}, + Stride<_8,_1>{}); +Layout tile = make_tile(make_layout(Shape<_2>{}), + make_layout(Shape<_4>{})); +print("layout: "); +print_layout(layout); +print("\n"); +print("tile: "); +print(tile); +print("\n"); +print("logical_divide: "); +print_layout(logical_divide(layout, tile)); +print("zipped_divide: "); +print_layout(zipped_divide(layout, tile)); +``` + +results in the following layouts. + +

+ logical_divide-and-zipped_divide-2 +

+ +#### Tiled divide + +The `tiled_divide` function works like `zipped_divide`, +except that it unpacks the second mode. This is useful when you have a `Tile` that describes all of the elements for a particular operation, for example, and want to gather those together but retain the logical shape of those tiles within the original layout. That is, + +```text +Layout Shape : (M, N, L, ...) +Tile Shape : +Tiled Result : ((M', N'), m, n, L, ...) +``` + +where `m` is `M / M'` and `n` is `N / N'`. +We can consider `m` as the "number of `Tile`s in `M`" and `n` as the "number of `Tile`s in `N`". This style of operation is common when applying MMA Atoms and Copy Atoms. diff --git a/media/docs/cute/03_tensor.md b/media/docs/cute/03_tensor.md new file mode 100644 index 0000000000..2382d834f7 --- /dev/null +++ b/media/docs/cute/03_tensor.md @@ -0,0 +1,262 @@ +# CuTe Tensors + +## A Tensor is a multidimensional array + +CuTe's `Tensor` class represents a multidimensional array. +The array's elements can live in any kind of memory, +including global memory, shared memory, and register memory. + +### Array access + +Users access a `Tensor`'s elements in one of three ways: + +* `operator()`, taking as many integral arguments as the number of modes, + corresponding to the element's (possibly) multidimensional logical index; + +* `operator()`, taking a `Coord` (an `IntTuple` of the logical indices); or + +* `operator[]`, taking a `Coord` (an `IntTuple` of the logical indices). + +### Slices: Get a Tensor accessing a subset of elements + +Users can get a "slice" of a `Tensor`, +that is, a `Tensor` that accesses a subset of elements +of the original `Tensor`. + +Slices happen through the same `operator()` +that they use for accessing an individual element. +Passing in `_` (the underscore character, an instance of `Underscore`) +has the same effect as `:` (the colon character) in Fortran or Matlab: +the resulting slice accesses all indices in that mode of the `Tensor`. + +### Tensor's behavior determined by its Layout and Engine + +A `Tensor`'s behavior is entirely determined by its two components, +which correspond to its two template parameters: `Engine`, and `Layout`. + +For a description of `Layout`, +please refer to [the `Layout` section](./01_layout.md) +of this tutorial, or the [GEMM overview](./0x_gemm_tutorial.md). + +An `Engine` represents a one-dimensional array of elements. +When users perform array access on a `Tensor`, +the `Tensor` uses its `Layout` to map from a logical coordinate +to a one-dimensional index. +Then, the `Tensor` uses its `Engine` +to map the one-dimensional index +to a reference to the element. +You can see this in `Tensor`'s implementation of array access. + +```c++ +decltype(auto) operator[](Coord const& coord) { + return engine().begin()[layout()(coord)]; +} +``` + +One could summarize almost all CuTe use cases as follows: + +* create `Layout`s, + +* create `Tensor`s with those `Layout`s, and + +* invoke (either CuTe's, or custom) algorithms on those `Tensor`s. + +### Ownership of the elements + +`Tensor`s can be owning or nonowning. + +"Owning" `Tensor`s behave like `std::array`. +When you copy the `Tensor`, you (deep-)copy its elements, +and the `Tensor`'s destructor deallocates the array of elements. + +"Nonowning" `Tensor`'s behave like a (raw) pointer to the elements. +Copying the `Tensor` doesn't copy the elements, +and destroying the `Tensor` doesn't deallocate the array of elements. + +Whether a `Tensor` is "owning" or "nonowning" depends entirely on its `Engine`. +This has implications for developers of generic `Tensor` algorithms. +For example, input `Tensor` parameters of a function +should be passed by const reference, +because passing the `Tensor`s by value +might make a deep copy of the `Tensor`'s elements. +It might also *not* make a deep copy of the elements; +there's no way to know without specializing the algorithm +on the `Tensor`'s `Engine` type. +Similarly, output or input/output `Tensor` parameters of a function +should be passed by (nonconst) reference. +Returning a `Tensor` might (or might not) +make a deep copy of the elements. + +The various overloads of the `copy_if` algorithm in +[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp) +take their `src` (input, source of the copy) parameter +as `Tensor& const`, +and take their `dst` (output, destination of the copy) parameter +as `Tensor&`. +Additionally, there are overloads for mutable temporaries like +`Tensor&&` +so that these algorithms can be applied directly to slices, +as in the following example. + +```c++ +copy(src_tensor(_,3), dst_tensor(2,_)); +``` + +In C++ terms, each of the expressions +`src_tensor(_,3)`, and `dst_tensor(2,_)` +is in the "prvalue" +[value category](https://en.cppreference.com/w/cpp/language/value_category), +because it is a function call expression +whose return type is nonreference. +(In this case, calling `Tensor::operator()` +with at least one `_` (`Underscore`) argument +returns a `Tensor`.) +The prvalue `dst_tensor(2,_)` won't match +the `copy` overload taking +`Tensor&`, +because prvalues can't be bound to +nonconst lvalue references (single `&`). +However, it will match the `copy` overload taking +`Tensor&&` +(note the two `&&` instead of one `&`). +Calling the latter overload binds the reference +to the prvalue `dst_tensor(2,_)`. +This results in +[creation of a temporary](https://en.cppreference.com/w/cpp/language/implicit_conversion#Temporary_materialization) +`Tensor` result to be passed into `copy`. + +### CuTe's provided `Engine` types + +CuTe comes with three `Engine` types. + +* `ArrayEngine`: an owning `Engine`, + representing an array of `N` elements of type `T` + +* `ViewEngine`: a nonowning `Engine`, + where `Iterator` is a random access iterator + (either a pointer to an array, or something that acts like one) + +* `ConstViewEngine`: a nonowning `Engine`, + which is the view-of-const-elements version of `ViewEngine` + +### "Tags" for different kinds of memory + +`ViewEngine` and `ConstViewEngine` wrap pointers to various kinds of memory. +Users can "tag" the memory with its space -- e.g., global or shared -- +by calling `make_gmem_ptr(g)` when `g` is a pointer to global memory, +or `make_smem_ptr(s)` when `s` is a pointer to shared memory. + +Tagging memory makes it possible for CuTe's `Tensor` algorithms +to use the fastest implementation for the specific kind of memory. +It also avoids incorrect memory access. +For example, some kinds of optimized copy operations require +that the source of the copy be in global memory, +and the destination of the copy be in shared memory. +Tagging makes it possible for CuTe to dispatch +to those optimized copy operations where possible. +CuTe does this by specializing `Tensor` algorithms +on the `Tensor`'s `Engine` type. + +### Engine members + +In order for a type to be valid for use as an `Engine`, +it must have the following public members. + +```c++ +using value_type = /* ... the value type ... */; +using iterator = /* ... the iterator type ... */; +iterator begin() /* sometimes const */; +``` + +## Constructing a Tensor + +### Nonowning view of existing memory + +A `Tensor` can be a nonowning view of existing memory. +For this use case, users can create the `Tensor` by calling `make_tensor` +with two arguments: a wrapped pointer to the memory to view, and the `Layout`. +Users wrap the pointer by identifying its memory space: +e.g., global memory (via `make_gmem_ptr`) or shared memory (via `make_smem_ptr`). +`Tensor`s that view existing memory can have either static or dynamic `Layout`s. + +Here are some examples of creating `Tensor`s +that are nonowning views of existing memory. + +```c++ +// Global memory (static or dynamic layouts) +Tensor gmem_8s = make_tensor(make_gmem_ptr(A), Int<8>{}); +Tensor gmem_8d = make_tensor(make_gmem_ptr(A), 8); +Tensor gmem_8sx16d = make_tensor(make_gmem_ptr(A), make_shape(Int<8>{},16)); +Tensor gmem_8dx16s = make_tensor(make_gmem_ptr(A), make_shape ( 8 ,Int<16>{}), + make_stride(Int<16>{},Int< 1>{})); + +// Shared memory (static or dynamic layouts) +Shape smem_shape = make_shape(Int<4>{},Int<8>{}); +__shared__ T smem[decltype(size(smem_shape))::value]; // (static-only allocation) +Tensor smem_4x8_col = make_tensor(make_smem_ptr(&smem[0]), smem_shape); +Tensor smem_4x8_row = make_tensor(make_smem_ptr(&smem[0]), smem_shape, GenRowMajor{}); +``` + +### Owning array of register memory + +A `Tensor` can also be an owning array of register memory. +For this use case, users can create the `Tensor` +by calling `make_tensor(layout)`, +where `T` is the type of each element of the array, +and `layout` is the `Tensor`'s `Layout`. +Owning `Tensor`s must have a static `Layout`, +as CuTe does not perform dynamic memory allocation in `Tensor`s. + +Here are some examples of creating owning `Tensor`s. + +```c++ +// Register memory (static layouts only) +Tensor rmem_4x8_col = make_tensor(make_shape(Int<4>{},Int<8>{})); +Tensor rmem_4x8_row = make_tensor(make_shape(Int<4>{},Int<8>{}), GenRowMajor{}); +Tensor rmem_4x8_mix = make_tensor(make_shape (Int<4>{},Int< 8>{}), + make_stride(Int<2>{},Int<32>{})); +Tensor rmem_8 = make_fragment_like(gmem_8sx16d(_,0)); +``` + +The `make_fragment_like` function makes an owning Tensor of register memory, +with the same shape as its input `Tensor` argument. + +## Tensor use examples + +### Copy rows of a matrix from global memory to registers + +The following example copies rows of a matrix (with any `Layout`) +from global memory to register memory, +then executes some algorithm `do_something` +on the row that lives in register memory. + +```c++ +Tensor gmem = make_tensor(make_gmem_ptr(A), make_shape(Int<8>{}, 16)); +Tensor rmem = make_fragment_like(gmem(_, 0)); +for (int j = 0; j < size<1>(gmem); ++j) { + copy(gmem(_, j), rmem); + do_something(rmem); +} +``` + +This code does not need to know anything the `Layout` of `gmem` +other than that it is rank-2 and that the first mode is a compile-time value. +The following code checks both of those conditions at compile time. + +```c++ +CUTE_STATIC_ASSERT_V(rank(gmem) == Int<2>{}); +CUTE_STATIC_ASSERT_V(is_static(gmem))>{}); +``` + +A `Tensor` encapsulates the data type, data location, +and possibly also the shape and stride of the tensor at compile time. +As a result, `copy` can dispatch, based on the types and Layouts of its arguments, +to use any of various synchronous or asynchronous hardware copy instructions +and can auto-vectorize the copy instructions in many cases as well. +CuTe's `copy` algorithm lives in +[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp). +For more details on the algorithms that CuTe provides, +please refer to [the algorithms section](./04_algorithms.md) +of the tutorial, or the +[CuTe overview in the GEMM tutorial](./0x_gemm_tutorial.md). + diff --git a/media/docs/cute/04_algorithms.md b/media/docs/cute/04_algorithms.md new file mode 100644 index 0000000000..e35b75612d --- /dev/null +++ b/media/docs/cute/04_algorithms.md @@ -0,0 +1,223 @@ +# CuTe Tensor algorithms + +This section summarizes the interfaces and implementations +of common numerical algorithms performed on `Tensor`s. + +The implementation of these algorithms may be found in the +[include/cute/algorithm/](../../../include/cute/algorithm/) +directory. + +## `copy` + +CuTe's `copy` algorithm copies the elements of a source `Tensor` +into the elements of a destination `Tensor`. +The various overloads of `copy` can be found in +[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp). + +### Interface and specialization opportunities + +A `Tensor` encapsulates the data type, data location, +and possibly also the shape and stride of the tensor at compile time. +As a result, `copy` can and does dispatch, +based on the types of its arguments, +to use any of various synchronous or asynchronous hardware copy instructions. + +The `copy` algorithm has two main overloads. +The first just takes the source `Tensor` and the destination `Tensor`. + +```c++ +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst); +``` + +The second takes those two parameters, plus a `Copy_Atom`. + +```c++ +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, + Tensor & dst); +``` + +The two-parameter `copy` overload picks a default implementation +based only on the types of the two `Tensor` parameters. +The `Copy_Atom` overload lets callers override that default +by specifying a nondefault `copy` implementation. + +### Parallelism and synchronization depend on parameter types + +Either the default implementation or +the implementation selected by a `Copy_Atom` overload +may use none or all available parallelism, +and may have a variety of synchronization semantics. +The behavior depends on `copy`'s parameter types. +Users are expected to figure this out based on their knowledge +of the architecture on which they are running. +(Developers often write a custom optimized kernel +for each GPU architecture.) + +The `copy` algorithm may be sequential per thread, +or it may be parallel across some collection of threads +(e.g., a block or cluster). + +If `copy` is parallel, +then the collection of participating threads +may need synchronization before any thread in the collection +may assume that the copy operation has completed. +For example, if the participating threads form a thread block, +then users must invoke `__syncthreads()` +or the Cooperative Groups equivalent +before they may use the results of `copy`. + +The `copy` algorithm may use asynchronous copy instructions, +such as `cp.async`, or its C++ interface `memcpy_async`. +In that case, users will need to perform +the additional synchronization appropriate to that underlying implementation +before they may use the results of the `copy` algorithm. +[The CuTe GEMM tutorial example](../../../examples/cute/tutorial/sgemm_nt_1.cu) +shows one such synchronization method. +More optimized GEMM implementations use pipelining techniques +to overlap asynchronous `copy` operations with other useful work. + +### A generic copy implementation + +A simple example of a generic `copy` implementation +for any two `Tensor`s looks like this. + +```c++ +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, // Any logical shape + Tensor & dst) // Any logical shape +{ + for (int i = 0; i < size(src); ++i) { + dst(i) = src(i); + } +} +``` + +This generic `copy` algorithm addresses both `Tensor`s +with 1-D logical coordinates, thus traversing both `Tensor`s +in a logical column-major order. +Some reasonable architecture-independent optimizations +would include the following. + +1. If the two `Tensor`s have known memory spaces with optimized + access instructions (like `cp.async`), then dispatch to the + custom instruction. + +2. The the two `Tensor`s have static layouts and it can be proven + that element vectorization is valid -- for example, four `LDS.32`s + can be combined into a single `LDS.128` -- then vectorize the source + and destinations tensors. + +3. If possible, validate that the copy instruction to be used is + appropriate for the source and destination tensors. + +CuTe's optimized copy implementations can do all of these. + +## `copy_if` + +CuTe's `copy_if` algorithm lives in the same header as `copy`, +[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp). +The algorithm takes source and destination `Tensor` parameters like `copy`, +but it also takes a "predication `Tensor`" +with the same shape as the input and output. +Elements of the source `Tensor` are only copied +if the corresponding predication `Tensor` element is nonzero. + +For details on why and how to use `copy_if`, +please refer to the +["predication" section of the tutorial](./0y_predication.md). + +## `gemm` + +### What `gemm` computes + +The `gemm` algorithm takes three `Tensor`s, A, B, and C. +What it does depends on the number of modes +that its `Tensor` parameters have. +We express these modes using letters. + +* V indicates a "vector," a mode of independent elements. + +* M and N indicate the number of rows resp. columns + of the matrix result C of the BLAS's GEMM routine. + +* K indicates the "reduction mode" of GEMM, + that is, the mode along which GEMM sums. + Please see the [GEMM tutorial](./0x_gemm_tutorial.md) for details. + +We list the modes of the input `Tensor`s A and B, +and the output `Tensor` C, +using a notation `(...) x (...) => (...)`. +The two leftmost `(...)` describe A and B (in that order), +and the `(...)` to the right of the `=>` describes C. + +1. `(V) x (V) => (V)`. The element-wise product of vectors: Cv += Av Bv. Dispatches to FMA or MMA. + +2. `(M) x (N) => (M,N)`. The outer product of vectors: Cmn += Am B_n. Dispatches to (4) with V=1. + +3. `(M,K) x (N,K) => (M,N)`. The product of matrices: Cmn += Amk Bnk. Dispatches to (2) for each K. + +4. `(V,M) x (V,N) => (V,M,N)`. The batched outer product of vectors: Cvmn += Avm Bvn. Optimizes for register reuse and dispatches to (1) for each M, N. + +5. `(V,M,K) x (V,N,K) => (V,M,N)`. The batched product of matrices: Cvmn += Avmk Bvnk. Dispatches to (4) for each K. + +Please refer to the [GEMM tutorial](./0x_gemm_tutorial.md) +for an overview of CuTe's convention for ordering the modes. +For example, if K appears, it always appears rightmost ("outermost"). +If V appears, it always appears leftmost ("innermost"). + +### Dispatch to optimized implementations + +Just like with `copy`, CuTe's implementations of `gemm` +uses its `Tensor` arguments' types to dispatch +to an appropriately optimized implementation. +Also like `copy`, `gemm` takes an optional `MMA_Atom` parameter +that lets callers override the default `FMA` instruction +that CuTe would select based on the `Tensor` arguments' types. + +For more information on `MMA_Atom` and on specialization of `gemm` +for different architectures, please refer to the +[MMA section of the tutorial](./0t_mma_atom.md). + +## `axpby` + +The `axpby` algorithm lives in the header file +[`include/cute/algorithm/axpby.hpp`](../../../include/cute/algorithm/axpby.hpp). +It assigns to $y$ the result of $\alpha x + \beta y$, +where $\alpha$ and $\beta$ are scalars and $x$ and $y$ are `Tensor`s. +The name stands for "Alpha times X Plus Beta times Y," +and is a generalization of the original BLAS "AXPY" routine +("Alpha times X Plus Y"). + +## `fill` + +The `fill` algorithm lives in the header file +[`include/cute/algorithm/fill.hpp`](../../../include/cute/algorithm/fill.hpp). +It overwrites the elements of its `Tensor` output argument +with a given scalar value. + +## `clear` + +The `clear` algorithm lives in the header file +[`include/cute/algorithm/clear.hpp`](../../../include/cute/algorithm/clear.hpp). +It overwrites the elements of its `Tensor` output argument with zeros. + +## Other algorithms + +CuTe provides other algorithms. +Their header files can be found in the +[`include/cute/algorithm`](../../../include/cute/algorithm) +directory. diff --git a/media/docs/cute/0t_mma_atom.md b/media/docs/cute/0t_mma_atom.md new file mode 100644 index 0000000000..7bdc407413 --- /dev/null +++ b/media/docs/cute/0t_mma_atom.md @@ -0,0 +1,434 @@ +# CuTe's support for Matrix Multiply-Accumulate instructions + +In this file, we explain in detail how we support our GPUs' +Matrix Multiply-Accumulate (MMA) hardware instructions in CuTe. + +MMAs are architecture-specific. +Different generations of GPU architectures +introduce different sets of MMA instructions. +However, CuTe features such as `Layout` +makes it possible to expose MMAs for use in generic CUDA C++ code. +We do this in two steps. + +1. We wrap each MMA's PTX instruction in an "Operation" struct. + +2. For each Operation struct, we define a "Traits" struct + that defines all of the meta-information needed to use the Operation. + +## CuTe MMA Atoms + +CuTe exposes each MMA to generic CUDA C++ code as a pair of structs: +an "Operation" struct, +and an `MMA_Traits` struct templated on the Operation struct type. + +An "Operation" struct exposes the PTX instruction +for that specific operation. +It defines the arguments and interface it expects. +Operation structs have minimal software dependencies -- +it does not use layouts, tensors, or non-standard numeric data types. +Different structs have different names +that describe what the MMA instruction does. +We will explain the naming scheme below. + +A corresponding `MMA_Traits` struct specialization +defines meta-information about the Operation, +such as the compute types, the logical shape of the operation, +and the `Layout`s of threads and values within the operation. +The `MMA_Traits` struct takes the Operation as a template parameter. +CuTe specializes `MMA_Traits` for each Operation type that it supports. + +Together, these two types comprise an "Atom" that decouples the complexity of thread and data layouts from the call site of of the PTX instruction. The Atom's Traits struct exposes information that is relevant to a single MMA operation, no matter the granularity at which it operates. + +CuTe MMA atoms expose the semantics of a single MMA operation. +This is true regardless of the hardware level at which the MMA operates. +CuTe supports MMA atoms that operate at a variety of hardware levels, +including + +* a single thread (e.g., fused multiply-add (FMA) instruction); + +* a quadpair (Volta); + +* a single warp (Ampere); and + +* a warpgroup (Hopper). + +### Operation structs + +#### Location of files + +CuTe provides its Operations structs in the +[`include/cute/arch`](../../../include/cute/arch) +directory, in header files starting with `mma`. + +#### Operation struct's name + +A CuTe Operation struct's name encodes information about + +* its first supported architecture, + +* the M, N, and K dimensions that it accepts, + +* the types that it takes, and + +* the expected A and B layouts. + +For example, the Volta section below will refer to the +`SM70_8x8x4_F32F16F16F32_NT` Operation struct defined in +[`include/cute/arch/mma_sm70.hpp`](../../../include/cute/arch/mma_sm70.hpp). + +* "SM70" refers to Volta. + +* "8x8x4" refers to M = 8, N = 8, and K = 4, + the dimensions of the MMA operation that the quadpair performs + (see below). + +* "F32F16F16F32" refers to the element types + of the four matrix operands A, B, C, and D. + An MMA computes D = C + A * B, + so we read the types from left to right: + D is F32 (`float`), A is F16 (half), + B is F16 (half), and C is F32 (`float`). + +* "NT" means that A is M-major (not transposed) + and B is N-major (transposed). + +#### Contents + +An Operation struct has the following members. + +##### Type aliases + +An Operation struct has four public type aliases: +`DRegisters`, `ARegisters`, `BRegisters`, and `CRegisters`. +For example, the `SM70_8x8x4_F32F16F16F32_NT` Operation struct defined in +[`include/cute/arch/mma_sm70.hpp`](../../../include/cute/arch/mma_sm70.hpp) +defines these as follows. + +```c++ +using DRegisters = float[8]; +using ARegisters = uint32_t[2]; +using BRegisters = uint32_t[2]; +using CRegisters = float[8]; +``` + +This shows how many values each thread will pass into the PTX instruction +for each of the matrices A, B, C, and D. For this Operation, +each thread passes 8 F32 values each for C and D (hence `float[8]`), +and 4 F16 values each for A and B (hence `uint32_t[2]`; +the instruction packs two 16-bit F16 values +in each of the two 32-bit `uint32_t` values). + +##### `fma` static member device function + +An operation struct defines a public `static void fma` function. +It is marked with the `CUTE_HOST_DEVICE` macro, +which adds the `__host__ __device__` annotations. +Different Operations define `fma` to take different numbers of arguments, +depending on the PTX MMA instruction. +The implementation protects use of the PTX instruction with a macro, +and raises an `assert` if `fma` is called when the macro is not defined. +This ensures that tests and examples that use this Operation in an Atom +can still compile, even if the PTX instruction is not available. + +### Traits + +#### Location of files + +CuTe provides its Traits structs in the +[`include/cute/atom`](../../../include/cute/atom) +directory, in header files starting with `mma_traits`. + +#### Contents + +An `MMA_Traits` specialization defines the following public type aliases. + +* `ElementDVal`: Compute type of the D matrix + +* `ElementAVal`: Compute type of the A matrix + +* `ElementBVal`: Compute type of the B matrix + +* `ElementCVal`: Compute type of the C matrix + +* `Shape_MNK`: Logical MxNxK shape of the MMA operation + +* `ThrID`: Logical thread mapping within the single MMA operation + (specifying the quadpair, warp, or warpgroup view) + +* `ALayout`: Mapping of (thread,value) pairs to the logical MxK A matrix + +* `BLayout`: Mapping of (thread,value) pairs to the logical NxK B matrix + +* `CLayout`: Mapping of (thread,value) pairs to the logical MxN C matrix + +#### Example + +The specialization of MMA_Traits for the +`SM70_8x8x4_F32F16F16F32_NT` Operation lives in the header file +[`include/cute/atom/mma_traits_sm70.hpp`](../../../include/cute/atom/mma_traits_sm70.hpp). +It looks like this. + +```c++ +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; +``` + +The next section will explain these type aliases in detail. + +## Volta + +This and the following sections show examples of how to construct MMA atoms. +We don't try to explain this for all GPU architectures and MMAs. +Instead, we use selected examples to illustrate the process +of developing new atoms. + +Volta architecture implements an HMMA instruction where a group of 8 threads called a quadpair (QP) collaborate to share data and perform an 8x8x4 (fp32 or fp16) matrix multiply-accumulate. (since a warp is 32 threads wide, it would perform an MMA across 4 QPs for a tile size of 16x16x4). + +We first take a look at how we would take the ISA semantics of thread and data partitioning for the HMMA instruction, and encode it in a Traits struct. The HMMA NT instruction has the thread-data layout: + +

+ HMMA.8x8x4.NT.png +

+ +### Types + +The HMMA NT above uses types: + +```cpp + using ElementDVal = float; + using ElementAVal = half_t; + using ElementBVal = half_t; + using ElementCVal = float; +``` + +The rest of the `MMA_Traits` will be described in units of these types. + +### Shape + +The HMMA NT above has shape 8x8x4: + +```cpp + // Logical shape of the MMA + using Shape_MNK = Shape <_8,_8,_4>; +``` + +### Thread ID + +If the 32 threads in a warp are logically indexed by [0 ... 31], then the above image contains threads [0,1,2,3]U[16,17,18,19]. These threads make up the 0th quadpair. We can write a thread mapping that maps eight logical thread ids [0,1,2,3,4,5,6,7] of the MMA to a quadpair thread index [0,1,2,3]U[16,17,18,19] of a warp. The layout function has 4 elements with a stride of 1 and 2 of those with a stride of 16. With this, we write a layout that represents a quadpair: + +```cpp + // Mapping from (logical thread id) -> (thread idx) + using ThrID = Layout, + Stride<_1,_16>>; +``` + +Again, this layout function maps the logical thread id [0,8) of the MMA operation onto the quadpair thread index [0,4)U[16,20) of a warp. + +### Accumulator Mapping + +Let us look at exactly how the 8 threads within a QP are mapped to the A, B and C matrices. For the C and D matrices, the above image is broken down a bit more below. On the left is shown the whole QP level view, and on the right is shown the values owned by just thread 0. + +

+ HMMA.8x8x4.quadpair.C.png +

+ +The metainformation of this single instruction level view is what we want to encode in CuTe. Specifically, the QP level view in this diagram corresponds to the four MMA traits for [SM70_F32F16F16F32](../../../include/cute/arch/mma_sm70.hpp). These structs contain the `Element` types, the `Shape_MNK`, and the `ThrID` mapping we constructed above. Now, let us take a look at the definition of `CLayout`, the thread-data layout of accumulators. The job of `CLayout` is to construct a mapping between the `(logical_thr_id, logical_val_id)` and `(m, n)` coordinate in the C matrix which can then be used to build up more complicated layouts and operations like the 16x16x4 WMMA. + +We can start constructing a `CLayout` from the picture above. As with any CuTe layout, it is a pair of `Shape` and corresponding `Stride`. Let us just look at the shape for now. We know that the HMMA uses 8 threads each of which own 8 values. Therefore, the shape of our mapping must have a size of 8 along two modes. With this, we have + +```cpp + // (T8,V8) -> (m,n) + using CLayout = Layout, + Stride<_?, _?>; // Stride to be filled in below +``` + +This is not to be confused with the logical 8x8 shape of the C matrix. This is 8-threads by 8-values. We now want to map those to (m,n) coordinates. Since CuTe layouts return indices rather than coordinates, we choose a column-major encoding of the (m,n) coordinates: + +``` +(logical_thr_id, logical_val_id) -> (m, n) == m + n * M +``` + +With this in place, we can start thinking about how to construct the strides in `CLayout`. Let's begin by looking at the strides between threads. Note that +* `(T0,V0)` is located at `(m,n) = (0,0) = 0` +* `(T1,V0)` is located at `(m,n) = (1,0) = 1` +* `(T2,V0)` is located at `(m,n) = (0,2) = 16` +* `(T3,V0)` is located at `(m,n) = (1,2) = 17` +* `(T4,V0)` is located at `(m,n) = (4,0) = 4` +* `(T5,V0)` is located at `(m,n) = (5,0) = 5` +* `(T6,V0)` is located at `(m,n) = (4,2) = 20` +* `(T7,V0)` is located at `(m,n) = (5,2) = 21` + +where `T4`,`T5`,`T6`,`T7` are the 4th,5th,6th,7th logical thread id of the MMA corresponding to thread indices of 16,17,18,19 of the warp (recorded in the `ThrID` mapping!). + +We note that the pattern can be transcribed to a layout. We can find the position of the 8 threads via + +```cpp + using CLayout = Layout, _8>, + Stride, _?>; +``` + +With the exact same approach, we can construct the stride along the `logical value id` mode. +* `(T0,V0)` is located at `(m,n) = (0,0) = 0` +* `(T0,V1)` is located at `(m,n) = (0,1) = 8` +* `(T0,V2)` is located at `(m,n) = (2,0) = 2` +* `(T0,V3)` is located at `(m,n) = (2,1) = 10` +* `(T0,V4)` is located at `(m,n) = (0,4) = 32` +* `(T0,V5)` is located at `(m,n) = (0,5) = 40` +* `(T0,V6)` is located at `(m,n) = (2,4) = 34` +* `(T0,V7)` is located at `(m,n) = (2,5) = 42` + +We note that this pattern can also be transcribed to a layout. We can find the position of the 8 values via + +```cpp + // (T8,V8) -> (m,n) + using CLayout = Layout, Shape <_2,_2, _2>>, + Stride, Stride<_8,_2,_32>>>; +``` + +And that's all! We can verify that each `(tid,vid)` coordinate in this layout is reliably mapped to the correct (encoded) `(m,n)` coordinate. + +In the case of F16 accumulators, the layout is way less complex. Each row of accumulators `(m, :)` is held by a single thread, which makes the layout: + +```cpp + using CLayout = Layout, + Stride<_1,_8>>; +``` + +### A and B Layout Mapping + +A and B matrix layouts depend on whether the sources are transposed or not. The diagram below shows the thread ID to data ownership map for A and B matrices in the case of NT and TN transposes. + +

+ HMMA.8x8x4.quadpair.AB.png +

+ +Let's look at the TN layout for A matrix first (right side in the diagram). Again, there are the same 8 logical threads, but each threads owns only 4 elements this time. The shape of `ALayout` will then be `Shape<_8, _4>`. As for the strides, we again need a similar mapping between `(m, k) == m + k * M`. Looking down the `M` mode, we go from `(T0, V0)` to `(T1, V0)` which is a stride of 1 for all 8 threads. For the `K` mode, as we go across, we go from `(T0, V0)` to `(T0, V1)`, which makes a stride of 8 for all 4 values. Therefore, the A layout is: + +```cpp + // (T8,V4) -> (m,k) + using ALayout = Layout, + Stride<_1,_8>>; +``` + +Source B layout is constructed similarly for the TN HMMA, except that we want write it as `(N,K)` rather than `(K,N)` for convenience. For the strides, as we go across the `N` mode, we go from `(T0, V0)` to `(T1, V0)`, making this a stride of 1 for all 8 threads. As we go down the `K` mode, `(T0, V0)` to `(T0, V1)` which is a stride of 8 for all 4 values. So the B layout is the same as A: + +```cpp + // (T8,V4) -> (n,k) + using BLayout = Layout, + Stride<_1,_8>>; +``` + +The layouts in the case of NT are a bit more complicated (left side of the diagram). Going down the `M` mode of `A`, we see the four values of `T0` first and then we see the four values of `T4`. This means we first have a stride of 1 for 4 values, followed by a stride of 4 from `T0` to `T4`. So we have two sub-strides along the `M` mode. For the `K` mode, as we go across, we simply increment the `thr_id`, keeping `val_id` the same, making the stride 8 for 4 threads. This makes the A layout: + +```cpp + // (T8,V4) -> (m,k) + using ALayout = Layout,_4>, + Stride,_1>>; +``` + +With the `(N,K)` ordering for B, the layout is the same. + +```cpp + // (T8,V4) -> (n,k) + using BLayout = Layout,_4>, + Stride,_1>>; +``` + +For the NN and TT transposes, they are simply combinations of the two layouts we have seen for A and B so far. + +## Hopper + +Now, we are ready to take a look at the much larger GMMA operation (Group MMA) first introduced with Hopper architecture. These MMA instructions operate at the granularity of 128 threads (4 warps), which are collectively referred to as a warpgroup. + +### Thread ID + +In the case of Hopper GMMAs, the thread IDs are assigned based on the simple 1D contiguous layout, which makes `thrID` trivial: + +```cpp +using ThrID = Layout<_128, _1>; +``` + +### Accumulator Mapping + +Accumulators are mapped hierarchically in GMMA, starting from the concept of a core matrix and building up to a layout for the whole C matrix tile. Let's look at this core matrix first. We only consider fp16 accumulators here, but extensions of fp32 accumulators as trivial as we will see later. + +Each core matrix has the layout as shown in the diagram below. +

+ gmma_coremat_cd_fp16.png +

+ +As in the Volta examples, the thread IDs are logical only, and which of the four warps they belong to in the warpgroup is not important. + +Then GMMA tiles this core matrix first vertically along the M mode, and then repeats that column of core matrices along the N mode to construct the full MxN tile. This tiling is shown in the image below. + +

+ gmma_wg_n_slice.png +

+ +With this image, we are again ready to start building the `CLayout` for `SM90_64x128x16_F16F16F16F16_TN` atom. Same as before, we are constructing a mapping between the `(logical_thr_id, logical_val_id) -> (m, n)` coordinate spaces. + +To begin, let's follow the first few threads and values. We immediately see that they are arranged along the `N`-mode with pairs of values and four threads. This gives us + +```cpp +// (T128,V4) -> (M64,N8) +using CLayout = Layout, Shape < _2, ...>>, + Stride, Stride<_64, ...>>>; +``` + +To complete the first 8x8 core matrix, the four threads repeat eight times down the `M`-mode: + +```cpp +// (T128,V4) -> (M64,N8) +using CLayout = Layout, Shape < _2, ...>>, + Stride, Stride<_64, ...>>>; +``` + +Then, as we go to the next core matrix, we wrap back again to `T0`, but this time to `(T0, V2)`. + +```cpp +// (T128,V4) -> (M64,N8) +using CLayout = Layout, Shape < _2, _2>>, + Stride, Stride<_64, _8>>>; +``` + +Finally, we get this entire pattern repeating four times, once for each warp, down the `M`-mode starting at `(m,n) = (16,0) = 16`. where two core matrices that belong to the same warp are stacked on top of each other. This makes the size of the final sub-mode of M 4. As for the stride, this time we go to `(T32, V0)`, which makes it a stride of 32. + +```cpp +// (T128,V4) -> (M64,N8) +using CLayout = Layout, Shape < _2, _2>>, + Stride, Stride<_64, _8>>>; +``` + +This is the full `CLayout` for 64x8 accumulators. The GMMA instructions include 64xN variants with `N = [16,32,64,128,256]` where this 64x8 pattern is repeated giving each thread additional values. As this starts at `(m,n) = (0,8) = 512`, this is easy to account for in our `CLayout`. For example, the 64x128 `CLayout` is + +```cpp +// (T128,V64) -> (M64,N128) +using CLayout = Layout, Shape < _2, _2, _16>>, + Stride, Stride<_64, _8, _512>>>; +``` + +where we see 16 copies of the 64x8 tile. + +### A and B Layout Mapping + +GMMA atoms that consume A and B sources directly from shared memory are a bit interesting. The GMMA Descriptor is constructed on an entore tile of A and/or B data in shared memory rather than being partitioned by threads. That is, every thread sees the entire tile of data and the tile is not reordered so that the descriptor can be constructed on it. In `ALayout` form, this can be expressed + +```cpp +// (T128,V64x8) -> (M64,K16) +using ALayout = Layout>, + Stride< _0, Stride< _1,_64>>>; +``` + +That is, all threads are mapped the to `(m,k) = (0,0) = 0` element and the values (and shape of the values) remains unchanged. The GMMA Descriptor Constructor can then inspect the `(M,K)` layout of this data and create an appropriate GMMA Descriptor or produce an error message saying the data is in an invalid layout for GMMA. diff --git a/media/docs/cute/0x_gemm_tutorial.md b/media/docs/cute/0x_gemm_tutorial.md new file mode 100644 index 0000000000..102010bb6b --- /dev/null +++ b/media/docs/cute/0x_gemm_tutorial.md @@ -0,0 +1,668 @@ +# CuTe dense matrix-matrix multiply tutorial + +This section uses the CuTe functionality to write +a dense matrix-matrix multiply implementation. + +## A simple dense matrix-matrix multiply example + +In this section, we will go through +[this example](../../../examples/cute/tutorial/sgemm_nt_1.cu). +It illustrates a blocked GPU implementation of GEMM +that uses the building blocks of CuTe +to construct global and shared memory layout mappings +and partition threads among them. +This example is closest to the blocked GEMM +that a computer science student might be asked to implement +in a first-year graduate school +or upper-division undergraduate scientific computing course. + +Readers who understand this section may also wish to study +CUTLASS's implementation of the stream-K GEMM algorithm, +which uses many features of CuTe. + +### Filename and high-level interface + +First, let's look at the example's filename `sgemm_nt_1.cu`. +"SGEMM" is the BLAS (Basic Linear Algebra Subroutines) abbreviation +for "Single-precision real, GEneral, Matrix-matrix Multiply." +(If we want to refer to matrix-matrix multiply for all data types, +we say "GEMM.") +The BLAS project started in the 1970s. +You can learn more about its history in Turing Award winner Jack Dongarra's +2004 Oral History interview by SIAM +(the Society for Industrial and Applied Mathematics), +and also in the C++ Standard document [P1417](https://wg21.link/p1417). +The abbreviation SGEMM unpacks as follows. + +* "Single-precision" is Fortran-speak for float. + The BLAS supports four different matrix or vector element types: + + * S for single precision (`float`), + + * D for double precision (`double`), + + * C for complex float (like C++'s `std::complex`, + where each of the real and imaginary components has type `float`), + and + + * Z for complex double (like C++'s `std::complex`). + +* "GEneral" means that the matrix is represented + as a two-dimensional dense array + and not assumed to have any kind of symmetry. + The BLAS supports a variety of matrix representations, + including + + * SY: SYmmetric, + + * HE: HErmitian, + + * TR: TRiangular, + + * GB: General Banded, + + * SB: Symmetric Banded, + + * SP: Symmetric Packed, and + + * TP: Triangular Packed. + +* MM means "Matrix-matrix multiply," as opposed to other operations, + like MV (Matrix-Vector multiply). + +The string "nt" in the filename means that +the first input matrix A is "Not transposed," +while the second input matrix B is "Transposed." +That is, the function computes `C := beta * C + alpha * A * B^T`, +where the superscript T denotes the transpose of the matrix. +(We never change the input matrix in place or +store its entire transpose explicitly. +Instead, we reinterpret its data in place.) + +GEMM's TRANSA and TRANSB arguments lets users specify +the transpose or Hermitian transpose (if complex) +of either or both input matrices A or B. +It turns out that implementations favor this "NT" case, +along with "TN" (A is Transposed, B is Not transposed). +We will explain why below. + +As described, the original BLAS GEMM specifies +the dimensions of its matrices +as A is M x K, B is K x N, and C is M x N. +Out of convenience, CuTe interprets A +as M x K, B as N x K, and C as M x N. Instead of row-major or column-major (or Transposed +and Not-Transposed like above), we like to be more specific with M-major, N-major, or K-major. +Regardless, we'll still use the BLAS "NT" notation for high-level descriptions +of kernels when it's appropriate. + +Now, let's look at the code. +We'll start with the kernel entry point `gemm_device` +at the top of the file. + +```c++ +template +__global__ static +__launch_bounds__(decltype(size(CThreadLayout{}))::value) +void +gemm_device(MShape M, NShape N, KShape K, + TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, + TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, + TC * C, CStride dC, CBlockLayout , CThreadLayout tC, + Alpha alpha, Beta beta); +``` + +There are many template parameters; +we'll explain them all in due time. + +`TA`, `TB`, and `TC` are the element types +of the matrices `A`, `B`, and `C`, respectively. +The two scalar constants `alpha` and `beta` +are part of what GEMM computes: `C = beta * C + alpha * A * B`. +Unlike the (traditional Fortran and C) BLAS, +CuTe lets you mix different matrix element types and/or scalar types. +The compiler will help, but it's somewhat up to you +to use types that are safe and efficient on the GPU. +For example, a custom arbitrary-precision real type +that does dynamic allocation inside may not work on the GPU at all. +Even if it does, it may not perform well. + +This leaves five kinds of things to explain: + +1. Shapes + +2. Strides + +3. Block layouts + +4. Thread layouts + +5. Launch bounds + +### Shapes + +The original Fortran BLAS GEMM lists the matrices' dimensions +in the order M, N, K. CuTe also uses this convention. +The "MShape" is just M, +the NShape is just N, +and the KShape is just K. +In this example, they are dynamic (run-time) values +defined at the top of the `gemm` host function +that invokes the device kernel. + +```c++ +// Define shapes (dynamic) +auto M = int(m); +auto N = int(n); +auto K = int(k); +``` + +Note that the function takes M, N, and K. +It doesn't take the shapes of the three matrices separately, +as (say) three different `Shape` objects. +This is because matrix-matrix multiply constrains the shapes. + +There's nothing mysterious about `int` here; +it's the usual C++ built-in integral type. +`auto M = int(m)` is a way to say +"convert `m` to an `int` if it's not already an `int`, +and assign it to the freshly declared variable `M`." +CuTe also has a capitalized `Int` templated type +for representing values as compile-time constants. +For example, `Int<5>` represents a compile-time `int` value 5. +(CuTe implements these as subclasses +of the C++ Standard Library class `std::integral_constant`.) +The above `gemm_device` function is templated on the types +of M, N, and K; this shows that CuTe can represent dimensions +as either run-time or compile-time values. + +If you're familiar with the mdspan class going into C++23, +you might notice that CuTe represents shapes +a bit differently from mdspan. +mdspan uses `extents` +to represent a shape. +The `Extents` are zero or more compile-time values +(see below) representing the dimensions in the shape. +The `Extents...` are "non-type template parameters" (NTTPs) -- +that is, they are not types, but compile-time values of type `size_t`. +If you use the special reserved `size_t` value `std::dynamic_extent` +as an extent value, +the resulting dimension is a run-time value +and is stored in the `extents` instance. +Any other extent value is a compile-time value +that is encoded in the extents type itself. +In contrast, CuTe represents a shape as `Shape`. +The `Types...` are actual types, not NTTPs. +A built-in integral type like `int` or `uint64_t` +denotes a run-time dimension that is stored in the `Shape` instance, +while a compile-time value like `Int<5>` +encodes a compile-time dimension. +For example, the CuTe equivalent of +`extents` +is `Shape, int, Int<5>>`. + +#### Compile-time-ness of values + +C++ values have three levels of "compile-time-ness": + +1. dynamic (run-time) values, + +2. constexpr values, and + +3. static (compile-time) values. + +(Rather than saying "C++ has," +it's more accurate to say "C++17 has." +C++20 introduces `consteval` or "immediate" functions, +which make attempting to evaluate the function at run time +(any call not in an unevaluated context) a compiler error. +We'll ignore those for this tutorial, +since CuTe only requires C++17.) + +The `constexpr` keyword was introduced in C++11. +It means something like +"the compiler can evaluate this expression at compile time." +It does NOT mean "the compiler MUST evaluate this at compile time." +If you use a `constexpr` expression in a `static_assert` +or as a non-type template argument, +then the compiler must evaluate the expression at compile time. +However, for `constexpr` occurring in other places, +the compiler may choose to store the value in registers or memory, +and/or do computations with the value at run time. +In some cases, the compiler must do that. +The following example shows that the compiler +might need to store `constexpr` values in memory sometimes. + +```c++ +// Some function defined in a different compilation unit. +extern int foo(int* x); + +int bar() +{ + constexpr int value = 42; // a compile-time constant + + // Even constexpr variables have a sizeof, + // because we still might need to take their address. + static_assert(sizeof(value) == 4); + + // Compiler can't inspect foo to see how it uses the value, + // so it has to store the value in some memory location + // so that we can pass its address to the function. + return foo(&value); +} +``` + +"Static" is an unfortunately overloaded term in C++. Sometimes it means "the opposite of instance," like a "static function" or "static member" of a class. (Some programming languages, like Java, say "class method" to refer to a "static function of a class.") That's not what we mean here. Instead, we mean "part of a compile-time type." For example, `Int<1>` encodes the value 1 at compile time, as part of the type of a templated class `Int`. `Int<3>` and `Int<4>` have different types. You can get the value of of the type like this: `Int<3>::value`. (The `value` is a `static constexpr` member of the class, where "static" means "opposite of instance.") As soon as you go from `Int<3>` to `Int<3>::value`, you've gone from (3) above (a compile-time value) to (2) above (a `constexpr` value). In some situations, this may mean that the compiler treats it as a run-time value. + +#### Strides + +We define a layout using both shapes and strides. +The shape just tells you the dimensions (modes, etc.) of the array. +The strides tell you the mapping from a multidimensional index +into a one-dimensional offset. +Here, we're describing the shapes and strides +of the "global" matrices A, B, and C. +The example defines the global matrices' strides +near the top of the `gemm` function. + +```c++ +// Define strides (mixed) +auto dA = make_stride(Int<1>{}, ldA); // (dM,dK) +auto dB = make_stride(Int<1>{}, ldB); // (dN,dK) +auto dC = make_stride(Int<1>{}, ldC); // (dM,dN) +``` + +To evaluate this mapping for a given multidimensional index, take the dot product of the indices with the strides. For example, the offset of `A(index_m, index_k)` is `index_m * 1 + index_k * ldA`. Note the implications for the compile-time-ness of the offset. Any run-time value among either the shape or the strides makes the offset a run-time value. Of course, if a particular stride is a compile-time constant (especially 1), it's easier for the compiler to optimize the arithmetic and result. + +Note that in the original source code, +this example is missing the comments after each line. +We've added them in here, +as they stir a brief digression about shapes and modes. +The comment after B says (dN, dK), not (dK, dN). +This means that B is treated as an N x K matrix +instead of a K x N matrix. +As mentioned, CuTe follows the convention +that the meaning of matrix modes is +(M,K) for A, (N,K) for B, and (M,N) for C. +In particular, CuTe's convention is that +"the reduction mode is outermost." +The "reduction mode" of `Shape` is K. +That's the mode over which we do a reduction, +that is, sum up products of matrix entries. +The K mode disappears in the output C. +"Outermost" here means "rightmost" +(literally, appearing rightmost in the list M, N, K). +Note that the shapes form a kind of Einstein tensor notation. +GEMM does Shape = Shape * Shape. +In Einstein notation, the repeated index indicates +a sum of that term over all values of K. + +We say in general that the leftmost mode is the "inner(most)" mode, +and the rightmost mode is the "outer(most)" mode. +This is because, +along with CuTe's convention of thinking of arrays as logically column major, +the leftmost mode is most commonly the mode with the most spatial locality. +It's very often the "most contiguous" mode. +For this reason, it's "the mode that we want in the innermost loop" +(in the nesting of loops that implements GEMM). +This is why we call it the "innermost" mode. +Its contiguity means that also call the innermost mode the "vector mode." + +The vector mode also has special meaning: +it contains all of the information needed +to execute the smallest possible computation or communication operations +on hardware, that is, what CuTe calls the "atoms." + +Modes are like units conceptually. +For example, you shouldn't mix M-mode indices with K-mode indices. +However, CuTe does nothing to enforce this. +(For example, CuTe does not require use of "tagged" index types. +Indexing works with the usual integer types.) + +The previous paragraph relates to shapes, not strides. +Returning to the strides, the above code describes these strides as "mixed." +This means that they include both run-time and compile-time values. +For example, the stride between A(m, k) and A(m+1, k) is `Int<1>`, +a compile-time value 1. The stride between A(m, k) and A(m, k+1), +however, is `ldA`, the "leading dimension of A," a run-time value. +The "leading dimension" of a matrix +refers to the stride between consecutive columns of a column-major matrix +(where the stride between consecutive rows is 1), +or the stride between consecutive rows of a row-major matrix +(where the stride between consecutive columns is 1). +This is a naming convention from the BLAS +and libraries that use it, like LAPACK. +For the purpose of this tutorial, it's just a naming convention +for "the stride that isn't the compile-time constant 1." + +#### M-major, N-major, K-major + +Note that we haven't uttered the phrases "column-major" or "row-major" here. This is where the experience of a BLAS user diverges from the experience of a BLAS implementer. BLAS users speak of "column-major" and "row-major" layouts. C++23's `mdspan` class encodes these as `layout_left` resp. `layout_right`. However, we don't speak of "column-major" or "row-major" in our GEMM implementations. + +We say that a matrix is "M-major" if it is stride 1 in the M-mode, "N-major" if it is stride 1 in the N-mode, or "K-major" if it is stride 1 in the K-mode. In the above code, A has shape (M, K) and strides (1, ldA). Since A has stride 1 in the M mode, we say that A is "M major." B has shape (N, K) and strides (1, ldB), so B is "N-major." Similarly, C has shape (M, N) and strides (1, ldC), so C is "M major." + +How do we translate this into the BLAS user's experience? +The following table illustrates for B and C. +(Throughout the table, "Impl" stands for "implementation.") + +Note that the implementation reverses the order of B's modes, +and flips B's strides. +Recall that one evaluates a layout +by taking the dot product of the indices and strides. +Thus, reversing the order of both the modes and the strides +does not change this evaluation. + +| Matrix | User's shape | User's layout | User's strides | Impl layout | Impl shape | Impl strides | +| --- | --- | --- | --- | --- | --- | --- | +| C | M x N | Column major | (1, LDC) | M-major | (M, N) | (1, LDC) | +| A | M x K | Column major | (1, LDA) | M-major | (M, K) | (1, LDA) | + +What about the matrix B? We explained above that B is N-major. How would that translate back into the BLAS user's experience? We take a hint here from the filename including "nt." The "nt" part of the name means that A is not transposed, while B is transposed. The BLAS convention (see e.g., [the documentation for DGEMM](https://netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html)) is that if you take the transpose, then the dimensions refer to the transpose ("with op( A ) an m by k matrix, op( B ) a k by n matrix and C an m by n matrix"). Thus, this example actually computes `C = beta * C + alpha * A * B^T`, where `B^T` is an K x N matrix with strides (LDB, 1). The user's "original" matrix B is thus N x K, with strides (1, LDB) -- that's a column-major layout. (Reversing the modes and the strides preserves the layout, since evaluating the layout mapping just takes the dot product of indices and strides.) This lets us expand the above table to include B. + +| Matrix | Transposed? | User's shape | User's layout | User's strides | Impl layout | Impl shape | Impl strides | +| --- | --- | --- | --- | --- | --- | --- | --- | +| C | No | M x N | Column major | (1, LDC) | M-major | (M, N) | (1, LDC) | +| A | No | M x K | Column major | (1, LDA) | M-major | (M, K) | (1, LDA) | +| B | Yes | N x K | Column major | (1, LDB) | N-major | (N, K) | (1, LDB) | + +CuTe developers say: "In CuTe, you can't tell transposed +apart from non-transposed, MN-major from K-major, etc. +without inspecting the strides." +It's now a bit more clear what that means. +CuTe doesn't see whether A or B are transposed. +Instead, CuTe sees shapes and strides. +A CuTe developer must reason backwards from the shapes and strides +in order to see what the BLAS user sees. + +Why does CuTe do this? Consider that matrix multiply performs a reduction in the K-mode. From the user's perspective, it's reducing across rows of the first input matrix, but across columns of the second input matrix. If we instead mentally flip the modes of the first input matrix, then the implementation reduces over columns (the K mode) of both input matrices. This leads to two cases in which the implementation can effectively treat both input matrices in the same way. (If you call it with A and B reversed, it should even give the same results for these cases.) + +| Case | User asks for A | User asks for B | Abbreviation | +| --- | --- | --- | --- | +| A is M major, B is N major | Not transposed | Transposed | NT | +| A and B are both K major | Transposed | Not transposed | TN | + +This is why an introductory example starts with NT or TN. +For a summary of the four different transpose options for A and B, +and their corresponding implementation layouts, +please see the table below. + +| Transpose abbreviation | User sees A transposed? | User sees B transposed? | A's impl layout | B's impl layout | +| --- | --- | --- | --- | --- | +| NT | No | Yes | M major | N major | +| TN | Yes | No | K major | K major | +| NN | No | No | M major | K major | +| TT | Yes | Yes | K major | N major | + +#### MN-major and K-major + +As we mentioned above, there are two "preferred arrangements," TN and NT. In the TN arrangement, both A and B are K-major. In the NT arrangement, A is M-major and B is N-major. Even though the two stride-1 modes in NT have different names, it's still the leftmost mode for both A and B that has stride 1. Thus, we can think of the NT arrangement as "MN-major," analogous to how the TN arrangement is "K-major." + +The two preferred arrangements tend to work themselves into implementations, particularly when they use hardware instructions for accelerating matrix multiplies of blocks. In some cases, the hardware instruction may require NT (MN-major) or TN (K-major). For NN or TT, such instructions would require an intermediate transpose -- for example, when loading from global memory to shared memory. + +### Block layouts + +Efficient matrix multiply implementations loop over blocks. +For example, a typical GPU implementation strategy +is for each thread block to iterate over some number of blocks. +In the example, this loop occurs near the end of `gemm_device`. + +```c++ +// TUTORIAL: Example of a very simple compute loop +// Data is read from global to shared memory via the tA|tB partitioning +// gemm(.) operates on the shared memory directly via the tC partitioning + +auto k_max = size<2>(tAgA); + +for (int k = 0; k < k_max; ++k) +{ + // Copy A and B blocks from global memory to shared memory. + copy(tAgA(_,_,k), tAsA); + copy(tBgB(_,_,k), tBsB); + + // On some architectures, copy may be asynchronous. + // This may call for extra synchronization instructions + // beyond just __syncthreads(). + + __syncthreads(); + + // Compute gemm on shared memory input and register accumulator. + // The "epilogue" after this loop will copy the accumulator + // from the register file into global memory. + gemm(tCsA, tCsB, tCrC); + + __syncthreads(); +} +``` + +We will explain the notation in this loop below. The important things to remember are that the coordinate `k` loops over the blocks which the calling thread is supposed to compute, the `copy` functions copy A resp. B blocks from global memory (the first argument) to shared memory (the second argument -- same as C++'s `std::copy`, but the opposite of `memcpy`), and the `gemm` function computes C += A * B on the shared memory blocks. + +It turns out that copy takes an optional first argument, the "atom," as in the following. + +```c++ +copy(atom, source, destination); +``` + +The "atom" is metadata that explains how to do the copy operation. + +There are a few topics to push onto the stack. + +The copy function call shows a notation for taking slices of a tensor. A CuTe `Tensor` is a multidimensional array view. It consists of a pointer and a `Layout`. You can learn more about `Tensor`s elsewhere in CuTe's documentation, but for now, please note that `tAgA(_,_,k)` means "create a Tensor that views (i, j, k) for all valid i, all valid j, and a specific value of k." The result has rank one less than the original Tensor. CuTe's underscore means the same thing as a single stand-alone colon in Fortran or Matlab. Note also that CuTe uses the same notation for slices as for tensor indexing. The implementation can distinguish the two cases by checking whether any of the arguments is an underscore. In contrast, the C++23 class mdspan uses a separate function, `submdspan` (not in C++23, and proposed for C++26; see [P2630](https://wg21.link/p2630)), for slicing. + +Fully understanding what `copy` and `gemm` do calls for learning about thread layouts as well, so we will wait to explain them completely. For now, note that these functions are implicitly parallel, as they are called collectively by all threads in a thread block. + +The block dimensions are defined near the top of the host function `gemm`. + +```c++ +// Define block sizes (static) +auto bM = Int<128>{}; +auto bN = Int<128>{}; +auto bK = Int< 8>{}; +``` + +We see that these are fully compile-time dimensions. This is often the case, especially when we use hardware instructions that only work for certain problem dimensions. Three lines of code immediately below these construct the block layouts. + +```c++ +// Define the block layouts (static) +auto sA = make_layout(make_shape(bM,bK)); +auto sB = make_layout(make_shape(bN,bK)); +auto sC = make_layout(make_shape(bM,bN)); +``` + +Here, the block layouts just come from the block dimensions. A Layout has two things: a Shape, and Strides. If the caller does not provide Strides, then CuTe computes Strides corresponding to the default "column-major" arrangement of data. This just happens to match the global matrices' layouts, but in general doesn't have to. For example, in the NN or TT cases, we may want to transpose one of the input matrices when copying from global memory to shared memory. + +The example "comments out" some code that prints all the layouts on "thread 0" of each thread block. If you enable the printing code and run the example, it will print all the layouts. For example, sA prints as + +``` +sA +(_128,_8) +(_1,_128) +``` + +and sB prints as + +``` +sB +(_128,_8) +(_1,_128) +``` + +consistently with the definitions above. + +If you have looked at other GEMM examples in CuTe, you might be wondering about hardware matrix-matrix multiply instructions. Those instructions tend to require certain values for shapes and strides, that may be a function of the matrix's element type. CuTe knows about these instructions and their required shapes and strides. We will go into more detail about that elsewhere. + +The `gemm_device` top-level kernel uses these block layouts to allocate shared memory buffers for A and B tiles. + +```c++ +// Shared memory buffers +__shared__ TA smemA[cosize_v]; +__shared__ TB smemB[cosize_v]; +``` + +Note how the shared memory buffers' sizes depend only on the A resp. B layouts (and element sizes). What's a `cosize_v`? The "`_v`" is a C++ naming convention that specifies a function from one or more template argument(s), to a value. In this case, it's a number of elements. A layout is a function from a set of multidimensional coordinates to a set of one-dimensional array offsets. It's a function, so we can speak of its domain and codomain. The "cosize" of a layout is the size of its codomain. (See e.g., CuTe's implementation of `Layout`.) If we want to allocate a linear array, for which all the offsets produced by a layout are valid, then we can use the cosize of the layout as the length of the array (in terms of number of elements, not in terms of number of bytes). + +### Thread layouts + +CuTe uses a `Layout` to describe the assignment of threads to work items. +In this example, the host function `gemm` constructs the thread layouts +for A, B, and C. + +```c++ +// Define the thread layouts (static) +auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); +auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); +auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); +``` + +That is, the thread layout for the A read is M-major 32x8, for the B read is N-major 32x8, and for the C compute/write is M-major 16x16. These thread layouts will partition the data for their respective stages. + +#### The example uses compile-time thread and block layouts + +Note that the device function `gemm_device` insists that all the thread and block layouts are static -- that is, known at compile time. You can see this from the `CUTE_STATIC_ASSERT` statements near the top of `gemm_device`. `CUTE_STATIC_ASSERT` is a wrapper for `static_assert`, which fails at compile time if its condition is `false`. + +```c++ +// Preconditions +CUTE_STATIC_ASSERT(is_static::value); +CUTE_STATIC_ASSERT(is_static::value); +CUTE_STATIC_ASSERT(is_static::value); + +CUTE_STATIC_ASSERT(is_static::value); +CUTE_STATIC_ASSERT(is_static::value); +CUTE_STATIC_ASSERT(is_static::value); +``` + +Use of static layouts has two advantages. First, it makes it easier to prove correctness of the algorithm. If the code compiles, it's likely correct. (On the other hand, new CuTe users may find themselves doing more debugging at compile time than they have before.) Second, it makes it easier and faster for CuTe to dispatch to the correct optimized implementations (called "atoms" -- see below) for copying blocks and performing matrix multiplies. + +#### The example's block gemm is parallel over elements of C + +In the actual device function, `tC` has layout `CThreadLayout`. You might recall that the kernel function `gemm_device` uses `CThreadLayout` to derive the launch bounds, specifically the maximum number of threads per block. The launch bounds show up in the declaration of `gemm_device`. + +```c++ +template +__global__ static +__launch_bounds__(decltype(size(CThreadLayout{}))::value) +void +gemm_device(MShape M, NShape N, KShape K, + TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, + TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, + TC * C, CStride dC, CBlockLayout , CThreadLayout tC, + Alpha alpha, Beta beta); +``` + +The "size" of `CThreadLayout` is the total number of threads, 16 * 16 = 256. (We take `::value` because the size is actually `Int<256>`, a compile-time constant with a `static constexpr int value = 256` member.) This suggests that the block gemm function (in the loop over blocks) parallelizes over elements of the C block. We can see this as well from the kernel launch (at the end of the `gemm` host function), which uses the size of `CThreadLayout` as the block dimension. + +```c++ +// Define the thread layouts (static) +auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); +auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); +auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); + +dim3 dimBlock(size(tC)); +dim3 dimGrid(ceil_div(size(M), size(bM)), + ceil_div(size(N), size(bN))); +gemm_device + <<< dimGrid, dimBlock, 0, stream >>> + (M, N, K, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); +``` + +Note that dimBlock is single-dimensional (despite being a dim3), as the size of a layout is a single value. We can see this also because the example only ever uses `threadIdx.x`, not `threadIdx.y`. Yet, C's thread layout has shape (16, 16). What's with that? Recall that a thread layout maps from a "logical" coordinate space (possibly multidimensional tuples of indices) to (one-dimensional) integer indices. In this case, `CThreadLayout` maps from pairs of indices in the Cartesian product space {0, 1, 2, ..., 15} x {0, 1, 2, ..., 15}, to one-dimensional indices 0, 1, 2, ..., 255. The latter, the output of `CThreadLayout`, is the actual thread index `threadIdx.x` in this case. `CThreadLayout` has only a shape (16, 16) and no nondefault strides, so it uses CuTe's default column-major arrangement (with strides (1, 16) in this case). + +#### What does `local_tile` do? + +The following code near the top of `gemm_device` +operates on the "global" (input and output) matrices A, B, and C +(where mA, mB, and mC are their Tensor representations). + +```c++ +// Get the appropriate blocks for this thread block -- +// potential for thread block locality +auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB)); // (BLK_M,BLK_N,BLK_K) +auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + +Tensor gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) +Tensor gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) +Tensor gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) +``` + +There are two new features here: + +* `make_coord`, which returns a `Coord`, a multidimensional index which can be used as the input of a `Layout`; and + +* `local_tile`, which we will explain below. + +The `Coord`(inate) `blk_coord` refers to the set of blocks (indexed by k -- the underscore here indicating a free parameter) our thread block will access. (The index k here doesn't mean the K mode; it's the same index as in the loop over blocks that does the computation.) + +If we print out the `gA`, `gB`, and `gC` layouts, we get the following. + +``` +gA +(_128,_8,512) +(_1,5120,40960) + +gB +(_128,_8,512) +(_1,5120,40960) + +gC +(_128,_128) +(_1,5120) +``` + +All of these layouts come from the original input or output matrices A, B, and C. Thus, they preserve the original strides, which are all the same in this example (when using default problem dimensions), 5120. This is most easily seen in the gC layout. For the other layouts, there is a clue in 5120 * 8 = 40960. That is, every time we increase k by one, we "skip over 8 columns" of the global matrix, over to the next block of data. This illustrates an important feature of CuTe, that it can view the same data with different modes and/or strides, as a way to identify parallelism or locality. + +## Next steps + +The above "simple GEMM" example's performance on many problems +is asymptotically optimal +with respect to the GPU's floating-point throughput. +Getting nearly peak performance +relative to the GPU's floating-point throughput, +for a wider variety of problem dimensions, +calls for more advanced techniques. +Please refer to other examples in this repository +to learn more about those techniques. +For example, the +[predication section of the tutorial](./0y_predication.md) +explains what to do if a matrix tiling +doesn't perfectly divide the matrix. + +### Implement GEMM as generalized tensor constraction (GETT) + +"GETT" here stands for "general(ized) tensor times tensor," +a tensor contraction. + +CuTe permits matrices to have nested `Layout`s. +For example, a matrix A can have a nested `Layout` for its M and N modes. +This means that we can use a "matrix" (`Tensor` with two modes) +to represent any `Tensor`. +This amounts to a "native hierarchical representation." + +As a result, we can implement GETT by using +our existing GEMM implementation layers, +with a little bit of fancy custom predication for the K mode. +This is because the stride type of A +and the problem shape itself +are CuTe Shapes and Strides. +This lets us represent the hierarchical modes +of a tensor contraction problem +(which still fundamentally only have 4 modes -- +batch mode, +two outer modes (one for A and one for B), +and one reduction mode -- +each of which can now have as many nested modes as you want +for the contraction's inputs). +We thus implement GETT as contraction just in one mode -- the K mode. +However, K itself can be hierarchical and can have noncontiguous strides. +We can reorder the modes such that all contraction modes +become a single, possibly hierarchical K mode in the kernel. +This is how we would encode a contraction in multiple modes at once. diff --git a/media/docs/cute/0y_predication.md b/media/docs/cute/0y_predication.md new file mode 100644 index 0000000000..f764508bf1 --- /dev/null +++ b/media/docs/cute/0y_predication.md @@ -0,0 +1,217 @@ +# Predication: What to do when tiling isn't perfect + +The [GEMM tutorial](./0x_gemm_tutorial.md) shows how +we compute a matrix-matrix multiply +by iterating over tiles of the input matrices and output matrix. +The examples all assume that the tiles fit evenly into the matrices, +with no remainder. +What do we do if this is not the case? +For example, we might want to tile a 41 x 55 matrix into 4 x 8 tiles, +but 41 / 4 is 10 remainder 1, and 55 / 8 is 6 remainder 7. +What do we do with those "leftover" parts of the matrix? + +Another way to say this, is that `logical_divide` +(CuTe's way of tiling layouts) "rounds up." +For example, if `N` is the layout (1000, 1) and `B` is the layout (128, 1), +then `logical_divide(N, B)` is the layout ((128, 8), (1, 128)). +This effectively rounds up the original shape N = 1000 +into an 128 x 8 matrix (as if N = 1024). +What about those last 24 elements, +that aren't part of the original data? + +The idiomatic CuTe way to solve this problem is through "predication." +Rather than trying to reason about the "remainder tiles," +CuTe instead rounds up, but only tries to access data in each tile +that are part of the matrix. +This corresponds well with how our GPUs optimize: +branches without warp divergence are relatively fast. +It also matches the usual CUDA idiom +when dividing N work items in 1-D fashion over B thread blocks: +first test if "my thread" is out of bounds before doing work. + +There are a few ways to figure out +which elements need to be predicated. +In-kernel GEMMs like to do this in the following way. + +```c++ +// Create the predicate tensor +Layout idA = make_layout(shape(A)); // e.g. 1000:1 +Layout idAB = logical_divide(idA, B); // e.g. (128,8):(1,128) + +Tensor pred = make_tensor(shape(idAB)); +for (int i = 0; i < size(pred); ++i) { + pred(i) = idAB(i) < size(A); +} + +// ... intervening code ... + +// Use the predicate tensor. c is some coordinate. +// This code would likely live inside some algorithm. +if (pred(c)) { copy(idAB(c), smem(c)); } +``` + +The general procedure is that we + +1. create an "identity" layout (`Layout idA = make_layout(shape(A))`, + in the above example) with the same shape as our original data; + +2. repeat the same tiling/partitioning/slicing (possibly rounding up) + on that identity layout (`Layout idAB = logical_divide(idA, B)`); + +3. create a "predicate tensor" by comparing the coordinates + of that reference layout with the bounds of the original layout; + and then + +4. use the predicate tensor to mask off accesses to out-of-bounds elements. + +For example, suppose that we've partitioned A and B tiles +across threads as follows. + +```c++ +Tensor tAgA = local_partition(gA, tA, thread_idx); // (THR_M,THR_K,k) +Tensor tAsA = local_partition(sA, tA, thread_idx); // (THR_M,THR_K,PIPE) + +Tensor tBgB = local_partition(gB, tB, thread_idx); // (THR_N,THR_K,k) +Tensor tBsB = local_partition(sB, tB, thread_idx); // (THR_N,THR_K,PIPE) +``` + +`tAgA` and `tBgB` partition the global A resp. B matrices over threads, +and `tAsA` and `tBsB` partition the shared memory tiles of A resp. B over threads. + +The following code creates predicate tensors +corresponding to `tAgA` and `tBgB`. +They will be computed once in the prologue. +and will be used to mask off instructions in the inner loop. + +```c++ +Tensor tApA = make_tensor(make_shape (size<0>(tAgA), size<1>(tAgA)), + make_stride( Int<1>{}, Int<0>{})); +Tensor tBpB = make_tensor(make_shape (size<0>(tBgB), size<1>(tBgB)), + make_stride( Int<1>{}, Int<0>{})); +``` + +We're only thread-parallelizing over the leftmost (row) dimension, +so we only need to predicate over the leftmost dimension. +Thus, we can make the rightmost (column) stride zero, +since we will never actually address the rightmost dimension. + +The following code creates "two-dimensional identity tensors" +that map coordinates (m,k) -> (m,k) +for the tile of data within the thread block. + +```c++ +Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) +Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) +``` + +The following lines then tile and partition +the two reference tensors +in exactly the same way the data were tiled and partitioned +into `tAsA` and `tBsB`. + +```c++ +Tensor tAcA = local_partition(cA, tA, thread_idx); +Tensor tBcB = local_partition(cB, tB, thread_idx); +``` + +Tiling and partitioning affect the offset and domain, +but not the codomain of the tensors, +so we're left with tensors that map `(thr_m,thr_k) -> (m,k)` +where `(thr_m,thr_k)` is this particular thread's subtensor of the tile +and `(m,k)` is the original codomain: a coordinate into the original tile. + +The unrolled loops in the code below then compare +the m- and n-coordinates of those tensors with our known maximums +to mask off elements we are not allowed to access. + +```c++ +Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) +Tensor tAcA = local_partition(cA, tA, thread_idx); + +Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) +Tensor tBcB = local_partition(cB, tB, thread_idx); + +// Populate +CUTE_UNROLL +for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(m,0)) < m_max_coord; +} +CUTE_UNROLL +for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(n,0)) < n_max_coord; +} +``` + +Those last `for` loops fill in the two predicate tensors. +In this case, we only need to predicate over the leftmost dimension, +so we only address `(m,0)` resp. `(n,0)`. + +We can then use the predicate tensors in `copy_if` +to copy only the elements for which the corresponding +predicate tensor elements are nonzero. + +```c++ +// Prefetch k_tile=0, gate these on k_residue as well +CUTE_UNROLL +for (int k = 0; k < size<1>(tAsA); ++k) { + if (get<1>(tAcA(0,k)) >= -k_residue) { // some other condition on the column index + copy_if(tApA, tAgA(_,k,0), tAsA(_,k,0)); + } +} + +CUTE_UNROLL +for (int k = 0; k < size<1>(tBsB); ++k) { + if (get<1>(tBcB(0,k)) >= -k_residue) { // some other condition on the column index + copy_if(tBpB, tBgB(_,k,0), tBsB(_,k,0)); + } +} +``` + +Here are some advantages of this "reference tensor" approach. + +1. It doesn't depend on the layout/strides of the tensor + being predicated, just the logical bounds being imposed. + +2. The partitioning stage can be anything. + +3. It naturally extends to any-dimensional predication. + +4. It's a natural generalization of a typical CUDA 1-D + parallel vector access pattern, + which computes an access index `k` + (e.g., as `blockDim.x * blockIdx.x + threadIdx.x`) + and then predicates access to the vector's `k`-th element + on whether `k` is in bounds. + +As an example of (3), the epilogue predication does exactly the same thing, + +```c++ +// Repeat with a tensor of coordinates for predication +Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); +Tensor tCcC = thr_mma.partition_C(cC); + +const bool isBetaZero = (beta == 0); + +CUTE_UNROLL +for (int i = 0; i < size(tCrC); ++i) { + if (elem_less(tCcC(i), make_coord(m_max_coord,n_max_coord))) { + tCgC(i) = isBetaZero ? alpha * tCrC(i) : alpha * tCrC(i) + beta * tCgC(i); + } +} +``` + +but with the mma responsible for the tiling/partitioning `tCcC` +so that the reference subtensor matches the accumulator's subtensor. +Then, the reference subtensor is predicated against the `if` bounds +(in both m- and n-coordinates) inside the `for` loop. + +Another way to explain this is that we don't modify the tiles +to give you the "right" extents so that you never overrun. +Instead, we let you query the original coordinate +to see if that coordinate overruns. +This avoids all branching and variable/dynamic loop bounds +(thus maintaining load balance and synchronicity, +both very important in-kernel) in favor of predication. +It's also general enough to extend to all ranks, +all layouts of threads and data, +and all tiling/partitioning patterns. diff --git a/media/docs/cutlass_3x_backwards_compatibility.md b/media/docs/cutlass_3x_backwards_compatibility.md new file mode 100644 index 0000000000..7be2a91bf8 --- /dev/null +++ b/media/docs/cutlass_3x_backwards_compatibility.md @@ -0,0 +1,473 @@ +[README](/README.md#documentation) > **CUTLASS 3.0 GEMM Backwards Compatibility** + +# CUTLASS 3.0 GEMM Backwards Compatibility + +Although CUTLASS 3.0 restructures the GEMM hierarchy and introduces new types for the +threadblock layer and below, we intend the entire source code to be usable in user applications. +We expect users to be able to `#include` any source file from CUTLASS 3.0, whether +they implement the 2.x or the 3.x API, without breaking user builds. This means that a single +translation unit should be able to contain any valid kernel regardless of its API version. The +sections below discuss how `device` and `kernel` layer type names are made compatible across the +two API versions, and what the users can expect out of the `threadblock` layer API going forward. + +## Compatible Device API + +The entry point for CUTLASS's Device GEMM API +is the class +`cutlass::gemm::device::GemmUniversalAdapter`. +This class lives in the header file +[include/cutlass/gemm/device/gemm_universal_adapter.h](/include/cutlass/gemm/device/gemm_universal_adapter.h). + +`GemmUniversalAdapter` is a "universal adapter" +and serves as a common device interface +for both CUTLASS 3.x and CUTLASS 2.x kernels. +Its template parameter `GemmKernel`, +the GEMM kernel type, can be any of the following: + +* `cutlass::gemm::kernel::GemmUniversal`, + implementing CUTLASS 3.x API kernels; +* `cutlass::gemm::kernel::GemmUniversal`, + implementing CUTLASS 2.x API kernels; +* Any valid CUTLASS 2.x `kernel` layer GEMM that + was previously composable with `device::GemmUniversalAdapter` + +Users implementing new kernels in either API should prefer +using `kernel::GemmUniversal` as the kernel type +and compose it with `device::GemmUniversalAdapter`. +Users with existing `kernel::Gemm` kernels +can continue to use them as template arguments +of `device::GemmUniversalAdapter`. They can adopt +`GemmUniversal` as a gradual migration path, +since `GemmUniversal` accepts either 3.0 or 2.x collectives. +Please see the [next section for `kernel::GemmUniversal`](#compatible-kernel-api) for details. + +`GemmUniversalAdapter` presents a single +host-side interface to both 3.0 and 2.x kernels. +CUTLASS accomplishes this by +specializing `GemmUniversalAdapter`'s implementation +on either 2.x API implementing kernel layer GEMMs, or 3.x API +implementing kernel layer GEMMs (as detected by `gemm::detail::IsCutlass3GemmKernel` +discussed below). As a result, `GemmUniversalAdapter`'s behavior +might differ between the two specializations. + +### Device API design differences + +In CUTLASS 2.x, the Device API was more closely tied +to the Kernel API. In CUTLASS 3.0, the Device API +accepts any kernel type that meets the Kernel API +interface requirements. CUTLASS 3.0's Device API code is +parameterized by the kernel type, but this code +is *generic*; the same code works for any kernel type. + +The device layer compatibility interface, `device::GemmUniversalAdapter`, +also provides reflective mappings from 3.0-specific types +back to the closest possible 2.x equivalent types. This is [discussed further in the section below](#conversions-between-2x-tags-and-30-types). + +CUTLASS 3.0's `device::GemmUniversalAdapter` also exposes some new APIs that the 2.x `device::GemmUniversalAdapter` implementation does not. Most notably, this includes the ability to bypass the `GemmKernel::Arguments` to `GemmKernel::Params` lowering. + +```c++ +// Primary run() entry point API that is static allowing users to create and manage their own params. +static Status +run(Params& params, cudaStream_t stream = nullptr); +``` + +This new API is useful for the following scenarios. + +* Running again does not require reinvoking `GemmKernel::to_underlying_arguments()` +* Manual control over construction of `GemmKernel::Params` for custom kernels with custom stride types +* Fully static problem shapes and strides for bespoke kernels where no argument mapping needs to take place + +## Compatible Kernel API + +CUTLASS 3.x API shares the kernel layer API with CUTLASS 2.x +through the single entry point type `cutlass::gemm::kernel::GemmUniversal`. +All kernel layer GEMMs are viewed as a composition of a collective mainloop +and a collective epilogue. + +**`kernel::GemmUniversal` implements both 2.x and 3.x APIs** + +The entry point for CUTLASS's kernel API is the class +`cutlass::gemm::kernel::GemmUniversal`. +This class' declaration lives in the header file +[include/cutlass/gemm/kernel/gemm_universal.hpp](/include/cutlass/gemm/kernel/gemm_universal.hpp). + +```c++ +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * SFIANE shims both 2.x and 3.0 API kernels based on ProblemShapeOrThreadblockMma_. +**/ +template < + class ProblemShapeOrThreadblockMma_, + class CollectiveMainloopOrEpilogue_, + class CollectiveEpilogueOrThreadblockSwizzle_, + class GridSwizzle_ = void, + class Enable = void +> +class GemmUniversal; +``` + +We call this class "universal" because it can be built +using either the CUTLASS 3.0 or the 2.x mainloops and epilogues. +If `GemmUniversal`'s first template argument +(`ProblemShapeOrThreadblockMma_`) is a `cute::tuple`, +then `GemmUniversal` assumes that +the remaining three template arguments +(the mainloop, epilogue, and grid swizzle) +implement the 3.0 APIs. +Otherwise, `GemmUniversal` assumes that +the remaining three template arguments +implement the 2.x APIs. +All the template arguments must be either +CUTLASS 3.0 or CUTLASS 2.x types. For example, +`GemmUniversal` does not permit using +a 2.x mainloop with a 3.0 collective epilogue. + +CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`. +Each kernel layer schedule is specialized +for a GEMM scheduling algorithm and GPU architecture. +Specializations of `kernel::GemmUniversal` for 3.0 APIs live in +any of various `gemm_*.hpp` files in the directory +[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +The specialization to which to dispatch is decided through the dispatch policy's `Schedule` type. + +Specializations for 2.x APIs live in the header file +[include/cutlass/gemm/kernel/gemm_universal.h](../../include/cutlass/gemm/kernel/gemm_universal.h). + +### Kernel API design differences + +The CUTLASS 2.x Kernel API was more closely tied +to the Device API, as we mentioned above. +In particular, the 2.x Device API specified the grid shape +used to launch the Kernel API. +In CUTLASS 3.0, the Kernel API controls its own grid shape, +while the device adapter simply queries the kernel with which it needs to be launched. + +This change is required to support various kernel schedules +that may need their own schedule specific grid planning logic. +For example, persistent kernel schedules generally only launch with +as many threadblocks as the number of multiprocessors on the GPU. + +All CUTLASS 3 `kernel::GemmUniversal` specializations expose the following (static) API: + +```c++ +// Returns true if the kernel can execute the provided GEMM arguments. +static bool +can_implement(Arguments const& args); + +// Returns a dim3 representing the threadblock shape. +static constexpr dim3 +get_block_shape(); + +// Returns a dim3 representing the grid shape in terms of threadblocks. +static constexpr dim3 +get_grid_shape(Params const& params); +``` + +The device adapter simply queries the kernel for these three before launching it on the device. +CUTLASS 3.0 provides a meta-function to detect whether a `cutlass::gemm::kernel::*` implements +the 3.x API or 2.x API: + +```c++ +// include/cutlass/gemm/gemm.h + +namespace cutlass:gemm::detail { + +// The following metafunction is used to detect whether a +// `kernel::Gemm` or `kernel::GemmUniversal` implements the CUTLASS 3.x API, +// by checking whether the problem shape type is aliased within. +template +struct IsCutlass3GemmKernel; + +} // namespace cutlass:gemm::detail +``` + +Users can dispatch their generic code against 2.x and 3.x specializations with +this as a type trait for the kernel API version. + +## Threadblock API and Inner Loops + +Much of the CUTLASS 3 GEMM hierarchy for mainloops and inner loops diverges +from that of CUTLASS 2.x. With that also comes the introduction of the +`cutlass::gemm::collective` layer as a direct replacement and a superset +of the 2.x `cutlass::gemm::threadblock` layer. Going forward, +CUTLASS 3.x will discontinue new developments in the following namespaces. + +* `cutlass::*::threadblock::*` +* `cutlass::*::warp::*` +* `cutlass::gemm::thread::*` +* `cutlass::arch::*` (except `barrier.h`) + +`cutlass::gemm::collective`s are a superset of the threadblock layer where +all new mainloops will be developed. Users should look to the `CollectiveMma` type +if they wish to author custom mainloop code in the 3.x API. + +Similarly, for the GEMM inner loops, `cute::MMA_Atom`s replace the +`gemm::warp` and `gemm::thread` layer code. Going forward, all new PTX instructions +and associated metadata development will occur directly inside [`cute/arch/*.hpp`](/include/cute/arch/) and [`cute/atom/*.hpp`](/include/cute/atom/). + +The desired inner loop MMA iteration order and tiling can be achieved through careful +selection of the atom layout, value layout, and permutations of the `cute::TiledMma`. + +For epilogues, the `cutlass::epilogue::collective` layer replaces `cutlass::threadblock::collective`. However, the thread-level epilogue elementwise operations +in `cutlass::epilogue::thread` will continue to be used in 3.x kernels as well, albeit, with +a more idiomatic epilogue vectorization strategy. +[Example 50](/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu) +shows how to use 2.x epilogue thread operators with 3.0 API kernels. + +## Porting from 2.x to 3.0 API + +### CUTLASS 2.x layout tags and CUTLASS 3.0 major modes + +CUTLASS 2.x and CUTLASS 3.0 use both +different wording and different types +to describe the permitted layouts +of GEMM's input matrices A and B. + +CUTLASS 3.0 does not use the terms "column major" +or "row major" to describe matrix layouts. +Starting with CUTLASS 3.0, adoption of CuTe allows us to decouple + +* the coordinate mode order (logical shape) of layouts from + +* the index space stride order of the backing storage. + +In line with our switch to a conceptual GEMM hierarchy, we view the major modes not from a BLAS-3 perspective. +Rather, we divide the modes into two categories. + +* "Inner modes" or "K-modes" are contracted over during the GEMM. + Therefore, they are not present in the output tensor. + +* "Outer modes" or "MN-modes" are preserved in the output. + +Now, instead of `RowMajor` or `ColumnMajor`, whose major stride depends on whether we are referring to the +A or the B matrix, we uniformly employ the "K major" or "MN major" terminology and enforce the convention of all tensors having the shape `[M/N, K, L]` regardless of which mode is major. That is, + +* the input matrix A has shape M x K, +* the input matrix B has shape N x K, and +* the input/output matrices C/D have shape M x N. + +Note that this convention for B +differs from the BLAS's GEMM interface, +which specifies that B has shape K x N. + +CUTLASS 3.0 uses these names of the modes +to specify which mode of a matrix has stride 1. +For the matrix A, + +* "M major" means that the matrix is stride 1 + in the M mode, and +* "K major" means that the matrix is stride 1 + in the K mode. + +For the matrix B, + +* "N major" means that the matrix is stride 1 + in the N mode (which for B is mode 0, + because the convention is that B is N x K); and +* "K major" means that the matrix is stride 1 + in the K mode (which for B is mode 1). + +CUTLASS 2.x defines "layout tag" classes +`cutlass::layout::ColumnMajor` and `cutlass::layout::RowMajor`, +that live in the header file +[`cutlass/layout/matrix.h`](/include/cutlass/layout/matrix.h). +The interpretation of these layouts in GEMM +depends on whether they are applied +to the input matrix A or B. For the matrix A, "column major" means +that mode corresponding to M extent has stride 1, +and "row major" means that mode corresponding to K extent has stride 1. +This is the usual computer science definition +of column major and row major for a rank-2 array. +For the matrix B, the opposite holds: +"column major" means that mode corresponding to N extent has stride 1, +and "row major" means that mode corresponding to K extent has stride 1. + +Using the convention of `[outer, inner, batch]` mode order for tensor logical shapes +avoids potential confusion with the meaning of column major and row major +changing depending on whether they are applied to A or B. + +The table below summarizes our mode order convention and +mapping of 2.x layout tags to corresponding M-major, N-major, or K-major strides. + +| Matrix | CUTLASS 2.x layout | 2.x Shape | Logical major mode| 3.x Shape/Stride | Major ordinal | +| --- | --- | --- | --- | --- | --- | +| A | `ColumnMajor` | M x K | M major | M x K x L | 0 (outer) | +| A | `RowMajor` | M x K | K major | N x K x L | 1 (inner) | +| B | `RowMajor` | K x N | N major | N x K x L | 0 (outer) | +| B | `ColumnMajor` | K x N | K major | N x K x L | 1 (inner) | +| C | `ColumnMajor` | M x N | M major | M x N x L | 0 (outer) | +| C | `RowMajor` | M x N | N major | M x N x L | 1 (inner) | + +Notice that in CUTLASS 3.0, interpretation of layouts no longer changes based on +whether we are talking about the A or B matrix. M and N major inputs always have a +static size-1 stride in their 0th (outer) mode. Similarly, K major inputs +always contain the static size-1 stride in their 1st mode. This uniformity in stride order +allows us to represent tensor layouts much more cleanly and treat both A and B equally in our interfaces. +See for example the following snippet from our [`kernel/sm70_gemm.hpp`](/include/cutlass/gemm/kernel/sm70_gemm.hpp) +for Ampere kernel schedules. + +```c++ +// Represent the full tensors +Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); // (m,k,l) +Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); // (n,k,l) + +// Get batch slice +Tensor mA_mk = mA_mkl(_,_,get<3>(blk_coord_mnkl)); // (m,k) +Tensor mB_nk = mB_nkl(_,_,get<3>(blk_coord_mnkl)); // (n,k) + +// Slice to get the tiles for which this thread block is responsible +Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) +Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) +``` + +As seem in this snippet, all input tensors have the logical shape `[outer, inner, batch]`, +and the strides could represent either outer or inner +(or any other complex hierarchical stride) major storage. +CuTe layouts always maintain the logical consistency of the coordinate spaces regardless of the strides. + +By convention, in CUTLASS 3.0, we treat the M and N mode as the 0th mode, +and K mode as the 1st mode of the stride. + +### Conversions between 2.x tags and 3.0 types + +Starting with CUTLASS 3.0, all layouts are described using +`cute::Shape` and `cute::Stride` which compose into a `cute::Layout`. +In CUTLASS 2.x, various layout tags such as `cutlass::layout::RowMajor` are used to specialize +template implementations. These tag types only encode information about the tensor strides, +as 2.x layouts did not incorporate any concept of tensor shape in the layout tags themselves. +Users may find a need to convert between CUTLASS 2.x layout tags, and 3.0 +CuTe stride types. CUTLASS 3.0 `gemm::collective::CollectiveBuilder` interfaces +also accept these 2.x layout tags as input parameters in their template API as a convenience for users. +At every entry point into CUTLASS 3.0, these tags get converted to their corresponding CuTe Stride type with +metafunctions that best approximate their corresponding `cute::Stride`. + +* `cutlass::gemm::detail::TagToStrideA_t` +* `cutlass::gemm::detail::TagToStrideB_t` +* `cutlass::gemm::detail::TagToStrideC_t` + +By convention, and to match user expectations, the `cute::Stride` types that these +map onto always contain one static mode corresponding to the layout tag, and two 64-bit +dynamic stride modes corresponding to the minor mode and the batch mode. Batch +mode is included by default as all CUTLASS 3.0 kernels support packed batch-mode GEMMs +out of the box. + +The [`cutlass/gemm/gemm.h#440`](../../include/cutlass/gemm/gemm.h#440) +header file includes functions +that can be useful for converting +from CUTLASS 3.0 `cute::Stride`s back to CUTLASS 2.x layout tags. + +* `cutlass::gemm::detail::StrideToLayoutTagA_t` +* `cutlass::gemm::detail::StrideToLayoutTagB_t` +* `cutlass::gemm::detail::StrideToLayoutTagC_t` + +These metafunctions take the CuTe Stride as a template parameter and +attempt to find the size-1 stride in the idiomatic M, N, or K modes +to best approximate a corresponding 2.x layout tag type. +Note that this may not work in general for any `cute::Stride` +as the mapping between the stride and tag type is not bijective. + +These mapping utilities are kept in a `detail` namespace +as we do not guarantee stability of their implementation. +Their behavior may change in future releases as we add new features. +However, we do expect these type names to remain stable. For users who want +these 2.x reflective types from an assembled kernel with a more stable API, +the specialization of `cutlass::gemm::device::GemmUniversalAdapter` +for CUTLASS 3.0 kernel provides all aliases for all 2.x type aliases +in addition to the layout tags. You can see how they are used in the header file +[`cutlass/gemm/device/gemm_universal_adapter.h`](/include/cutlass/gemm/device/gemm_universal_adapter.h). +Here is an excerpt. + +```c++ + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + // If our TiledMMA's instruction thread layout size is larger than 1, + // we know it's a tensorop + using OperatorClass = std::conditional_t< + (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // We get the instruction shape directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x, + // but we can best approximate it by inspecting the TiledMma::TiledShape_MNK. + // For this, we make the assumption that we always have 4 warps along M, + // and the rest along N, with none along K. We also always round up + // the warp count to 4 if the tiled mma is smaller than 128 threads. + static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, + cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, + cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); + static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); +``` + +CUTLASS's library and profiler use these reflective interfaces to +obtain the kernel's configuration parameters. Users can use these to approximate the CUTLASS 2.x types +for 3.0 API kernels. However, the reflective interfaces cannot always match the types exactly, +as the mappings are not always bijective. + +# Copyright + +Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/media/docs/cutlass_3x_design.md b/media/docs/cutlass_3x_design.md new file mode 100644 index 0000000000..9db3359d26 --- /dev/null +++ b/media/docs/cutlass_3x_design.md @@ -0,0 +1,117 @@ +[README](/README.md#documentation) > **CUTLASS 3.0 Design and Hierarchy** + +# CUTLASS 3.0 Design + +CUTLASS 3.0 is a major enhancement over the abstractions of CUTLASS 2.x +and aims to make usage of all layers of the GEMM hierarchy easier and more composable +while still achieving peak performance on Hardware. + +## CUTLASS 3.0 design goals + +CUTLASS 3.0 has the following design goals, in no particular order. + +- Simplify expressing and manipulating data and thread layouts across + the GEMM hierarchy with CuTe layouts and layout algebra. + +- Improve code readability and learning curve by + reducing the number of named types. + +- Functional correctness by default, + actionable static asserts otherwise. + +- Single, clear points of performance tuning and custom kernel extensions. + +- Support for NVIDIA Hopper GPUs with great performance using + features such as Tensor Cores, tensor memory accelerator, and thread block clusters. + +## A new Conceptual GEMM Hierarchy + +CUTLASS 2.x decomposes the moving parts of a GEMM operation +across a hierarchy that closely mirrors the organization of GPU +architectures. This discussed in detail within the +[CUTLASS 2.x GEMM API documentation](/media/docs/gemm_api.md). +This design, however, sometimes results in a coupling that is too tight +to extend to newer GPU features that might not fit into the same architectural +hierarchy. For instance, Hopper's warp-group wide instructions do not naturally +fit into any warp or thread layer GEMM concept in CUTLASS 2.x. Even for Volta tensor cores, +instructions that atomically exist at the quad-pair granularity are first tiled at +the warp level before use. This hints at the brittleness of the abstraction power. + +CUTLASS 3.0 detaches its interface layers from the hardware, +centering them instead around the natural structure of GEMM algorithms +not tied to any particular GPU generation. +This makes CUTLASS's code more robust to GPU architecture evolution, +less prone to implementation detail leakage, and provides users +with a consistent interface to hardware acceleration regardless of +the architecture specific details. + +The new conceptual GEMM hierarchy is discussed in detail in the dedicated +[CUTLASS 3.0 GEMM API documentation readme](/media/docs/gemm_api_3x.md), +along with code examples of the core concepts and types. + +## Adoption of CuTe Layout and Tensors + +CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data. +CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly packages the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. + +CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates, greatly simplifying the design, +improving code composability, and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md). + +![CuTe helps reduce named iterator types down to a single vocabulary type, `Layout`](/media/images/cutlass-reduction-in-named-iterators.png) + +Programming massively parallel systems with various layers of logical thread and data hierarchies is not a trivial task. + +- `cute::Layout`s always maintain logical consistency of their coordinates, + allowing us to check pre- and post-conditions at compile time for all static inner loops. +- Explicit thread to data mapping allows users and kernel authors to inspect and reason about operations + from a single point in the source code. +- Layouts provide a single point of performance tuning, as most optimizations can be done by careful + selection of thread and data layouts. +- Formalized algebra makes manipulation of and reasoning about thread->data mapping explicit in source code. +- Single vocabulary type (`cute::Layout`) subsumes every iterator and layout in CUTLASS 2.x CUTLASS 2.x uses many bespoke thread maps, iterators, and data layouts. Iterators are fundamentally 1-D, whereas most layouts we encounter in the GPU hierarchy are fundamentally n-D. + +## Reducing the number of named types and iterator concepts + +CUTLASS 2.x design preferred introducing bespoke named types for each +architecture specific thread and data layout. For instance, `gemm::treadblock` namespace +contains implementation for `MmaMultistage`, `MmaPlanarComplexMultistage`, `MmaPipelined` etc. +despite them providing mainloops for GEMMs. To spell these types the same way in generic code, +CUTLASS 2.x provides aliases through its `default_x_configuration.h` files, however, +these aliases make the code much harder to read as the user has to perform type substitution +mentally in order to understand the codebase. + +CUTLASS 3.0 greatly reduces the number of named types used throughout by + +- Replacing all iterator concepts for all memory domains with `cute::Tensor`s +- Dispatching mainloop and epilogue implementations on tag-dispatch policies rather than naming new types +- Dispatching kernel layer schedules on tag-dispatch policies rather than naming new types + +Reducing the number of named types has many benefits: + +- It *makes writing generic code easier*, as the primary type names share the same lexical + without aliasing through configuration providers. +- It *flattens the learning curve of CUTLASS* by greatly reducing the mental context required + as the library only exposes a handful of named types. +- It *provides a clear, singular extension point* for users to plug in their customizations + through the dispatch policies. + +## Correctness by default, Performance through clear, individual points of tuning + +CUTLASS 2.x maintained its thread layouts as implicit indexing math implemented +as a part of 1D iterators. This meant that the thread to data layout mapping +was implicit in the imperative structure of the C++ code itself and did not have +a formal algebra we could use to manipulate these mappings. Each iterator +had to re-implement its indexing and mapping logic. This made it hard to learn +how this mapping was performed for existing iterators, and even harder to +implement custom layout functions for the core inner loops of a GEMM. + +CUTLASS 3.0 replaces all iterator concepts from CUTLASS 2.x +with a single layout type for thread and data tensors. +CuTe's formalized layout algebra is then used at every layer of +the GEMM hierarchy to manipulate the mapping between the two. +CuTe layouts always maintain logical consistency, and for fully static layouts +(such as in the core unrolled inner loops), provide +compile time checks that break builds if this consistency is violated. +In this way, CuTe reifies the thread-to-data-layout mapping, +makes it easier to write code that is "correct by construction". +If the code compiles, it's probably correct. diff --git a/media/docs/doxygen_mainpage.md b/media/docs/doxygen_mainpage.md index 1cb5a56b07..4145748164 100644 --- a/media/docs/doxygen_mainpage.md +++ b/media/docs/doxygen_mainpage.md @@ -1,14 +1,14 @@ -# CUTLASS 2.0 +# CUTLASS 3.0 -_CUTLASS 2.0 - November 2019_ +_CUTLASS 3.0 - January 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. It incorporates strategies for hierarchical decomposition and data movement similar to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into reusable, modular software components abstracted by C++ template classes. These -thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized -and tuned via custom tiling sizes, data types, and other algorithmic policy. The +components can be specialized +and tuned via custom tiling sizes, data types, and other algorithmic policies. The resulting flexibility simplifies their use as building blocks within custom kernels and applications. @@ -16,107 +16,25 @@ To support a wide variety of applications, CUTLASS provides extensive support fo mixed-precision computations, providing specialized data-movement and multiply-accumulate abstractions for 8-bit integer, half-precision floating point (FP16), single-precision floating point (FP32), and double-precision floating -point (FP64) types. Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply -operations for targeting the programmable, high-throughput _Tensor Cores_ implemented -by NVIDIA's Volta and Turing architectures. +point (FP64) types. Furthermore, CUTLASS exploits the _Tensor Cores_ and asynchronous +memory copy operations of the latest NVIDIA GPU architectures. +# What's New in CUTLASS 3.0 -# What's New in CUTLASS 2.0 +For an overview of CUTLASS 3.0's GEMM interface levels, +please refer to the +[CUTLASS 3.0 GEMM API document](./gemm_api_3x.md). +To learn how to migrate code using CUTLASS 2.x's interface +to CUTLASS 3.0, please refer to the +[backwards compatibility document](./cutlass_3x_backwards_compatibility.md). -CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer: +# GEMM examples -- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores -- Robust and durable templates that reliably span the design space -- Encapsulated functionality that may be reusable in other contexts - - -# Example CUTLASS GEMM - -The following illustrates an example function that defines a CUTLASS GEMM kernel -with single-precision inputs and outputs. This is an excerpt from the CUTLASS SDK -[basic_gemm example](https://github.com/NVIDIA/cutlass/tree/master/examples/00_basic_gemm/basic_gemm.cu). - -~~~~~~~~~~~~~~~~~~~~~{.cpp} -// -// CUTLASS includes needed for single-precision GEMM kernel -// - -// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class. - -#include - -/// Define a CUTLASS GEMM template and launch a GEMM kernel. -cudaError_t cutlass_sgemm_nn( - int M, - int N, - int K, - float alpha, - float const *A, - int lda, - float const *B, - int ldb, - float beta, - float *C, - int ldc) { - - // Define type definition for single-precision CUTLASS GEMM with column-major - // input matrices and 128x128x8 threadblock tile size (chosen by default). - // - // To keep the interface manageable, several helpers are defined for plausible compositions - // including the following example for single-precision GEMM. Typical values are used as - // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details. - // - // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h` - - using ColumnMajor = cutlass::layout::ColumnMajor; - - using CutlassGemm = cutlass::gemm::device::Gemm; // Layout of C matrix - - // Define a CUTLASS GEMM type - - CutlassGemm gemm_operator; - - // Construct the CUTLASS GEMM arguments object. - // - // One of CUTLASS's design patterns is to define gemm argument objects that are constructible - // in host code and passed to kernels by value. These may include pointers, strides, scalars, - // and other arguments needed by Gemm and its components. - // - // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible - // arguments to kernels and (2.) minimized initialization overhead on kernel entry. - // - - CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions - {A, lda}, // Tensor-ref for source matrix A - {B, ldb}, // Tensor-ref for source matrix B - {C, ldc}, // Tensor-ref for source matrix C - {C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix) - {alpha, beta}); // Scalars used in the Epilogue - - // - // Launch the CUTLASS GEMM kernel. - // - - cutlass::Status status = gemm_operator(args); - - // - // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. - // - - if (status != cutlass::Status::kSuccess) { - return cudaErrorUnknown; - } - - // Return success, if no errors were encountered. - - return cudaSuccess; -} -~~~~~~~~~~~~~~~~~~~~~ +For a code example showing how to define +a GEMM kernel using CUTLASS, please refer to +[the quickstart guide](./quickstart.md). +The [`examples` directory](../../examples) +has a variety of examples. # Copyright diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index 359d5794a3..533ebc85df 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -219,6 +219,21 @@ which has to happen at the end among the participating warps. This is because each warp computes using only a "slice" of CtaTileK, so each warp only has a partial sum before the reduction. +### Warp Specialization + +Starting with Hopper, CUTLASS 3.0 incorporates the concept of [Warp Specialization](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization) +as part of the kernel design. A thread block is partitioned into two sets of warps, [*producer* warp group](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [*consumer* warp group](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp). The *producer* warp group loads data from global memory into shared memory buffers using the new [Tensor Memory Accelerator (TMA)](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). + +[*Producer* warp group (DMA)](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) waits for the shared memory buffers to be signaled as [empty](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) by the *consumer* warp group using the newly added **Async Pipeline class** ([refer](/media/docs/pipeline.md)). Once the data is written into the shared memory, TMA is also updates the barrier associated with that stage to notify affected threads that the buffer has been [filled](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp). The [*Consumer* warp group (MMA)](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) on the other hand waits for the *producer* warp group to signal that the buffer is [filled](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) and then launches tensor core MMA operations. Finally, the *consumer* warp group [releases](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) the buffers for the next set of TMA loads to happens. + +**Warp-Specialized Persistent kernel design** + +Another flavor of Warp Specialized kernel design being introduced starting with Hopper is the [*Warp-Specialized Persistent*](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp) kernel. Like Warp Specialized kernel the concepts of warp groups and barrier synchronization between warp groups remain the same in the persistent design. The distinctive feature of the Warp-Specialized Persistent kernel are the following : +* Persistent thread blocks launched to occupy as many SMs as mentioned in the [KernelHardwareInfo](include/cutlass/kernel_hardware_info.hpp) struct. These persistent thread blocks are used to tile the output and thus (potentially) compute multiple output tiles through their lifetime. The main benefit this adds is amortization of the thread-block launch and kernel prologue overheads which are typical of all kernels. +* Presence of one two *consumer* warp groups which allows for *epilogue* of one *consumer* warp group to be overlapped with the math operations of the other *consumer* warp group - thus maximizing tensor core utilization. + +Each *consumer* warp group is assigned a different output tile. The *producer* warp group synchronizes using the [Ordered Sequence Barrier](/include/cutlass/pipeline.hpp) to fill buffers of the two *consumer* warp groups one after the other in order. Since each thread block now computes multiple output tiles, the shape of the grid launch and the scheduling of tiles to the thread blocks is managed using the new [*Tile Scheduler*](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). The *Tile Scheduler* considers the shape of the *clusters* as well as the available number of available SMs to compute a valid scheduling of the output tiles to launched thread blocks. + # Resources The following additional resources describe design and implementation details of GEMMs diff --git a/media/docs/functionality.md b/media/docs/functionality.md index 71bc9b0925..fea258f4ab 100644 --- a/media/docs/functionality.md +++ b/media/docs/functionality.md @@ -4,12 +4,15 @@ # Functionality +Note : CUTLASS-3 requires users to use CUDA 11.4 or newer, and SM70 or newer, for the target toolkit and architecture, respectively. +Please refer to the [Compatibility](/README.md#Compatibility) section for more details. + - N - Column Major Matrix - T - Row Major matrix -- {N,T} x {N,T} - All combinations, i.e. NN, NT, TN, TT +- {N,T} x {N,T} - All combinations, i.e., NN, NT, TN, TT - [NHWC](/include/cutlass/layout/tensor.h#L63-206) - 4 dimension tensor used for convolution - [NCxHWx](/include/cutlass/layout/tensor.h#L290-395) - Interleaved 4 dimension tensor used for convolution -- f - float point +- f - floating point - s - signed int - b - bit - cf - complex float @@ -22,42 +25,55 @@ ## Device-level GEMM -The following table summarizes device-level GEMM kernels in CUTLASS, organized by opcode class, data type, and layout. +The following tables summarize device-level GEMM kernels in CUTLASS, organized by opcode class, data type, and layout. Hyperlinks to relevant unit tests demonstrate how specific template instances may be defined. +### CUTLASS 3.x Kernels + +|**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | +|-----------------|------------------------|------------------|--------------------------------|------------------------|------------------| +| **TensorOp** | 90a | 12.0+ | `f16 * f16 + { f16, f32 } => { f16, f32 }` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu) | +| **TensorOp** | 90a | 12.0+ | `bf16 * bf16 + { f16, f32 } => { bf16, f32 }`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu) | +| **TensorOp** | 90a | 12.0+ | `{f32, tf32} * {f32, tf32} + f32 => f32`| { T } x { N } => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu) | +| **TensorOp** | 90a | 12.0+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu) | + + +### CUTLASS 2.x Kernels + |**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | |-----------------|------------------------|------------------|--------------------------------|------------------------|------------------| -| **Simt** | 50,60,61,70,75 | 9.2+ | `f32 * f32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_sgemm_nt_sm50.cu) | -| **Simt** | 50,60,61,70,75 | 9.2+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_dgemm_nt_sm50.cu) | -| **Simt** | 60,61,70,75 | 9.2+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_hgemm_nt_sm50.cu) | -| **Simt** | 61,70,75 | 9.2+ | `s8 * s8 + s32 => {s32,s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_igemm_nt_sm50.cu) | -| **WmmaTensorOp** | 70 | 9.2+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu) | -| **WmmaTensorOp** | 70 | 9.2+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu) | -| **WmmaTensorOp** | 75 | 10.0+ | `s8 * s8 + s32 => {s32, s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu) | -| **WmmaTensorOp** | 75 | 10.0+ | `s4 * s4 + s32 => {s32, s4}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s4t_wmma_tensor_op_s32_sm75.cu) | -| **WmmaTensorOp** | 75 | 10.0+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_b1t_wmma_tensor_op_s32_sm75.cu) | -| **TensorOp** | 70 | 10.1+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu) | -| **TensorOp** | 70 | 10.1+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f32_sm70.cu) | -| **TensorOp** | 75 | 10.2+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `bf16 * bf16 + f32 => {bf16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_bf16n_bf16t_bf16t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `tf32 * tf32 + f32 => f32`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | -| **SpTensorOp** | 80 | 11.1+ | `f16 * f16 + f32 => {f16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80 | 11.1+ | `bf16 * bf16 + f32 => {bf16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80 | 11.1+ | `tf32 * tf32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80 | 11.1+ | `s8 * s8 + s32 => {s8, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu) | -| **SpTensorOp** | 80 | 11.1+ | `s4 * s4 + s32 => {s4, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_sgemm_nt_sm50.cu) | +| **Simt** | 50+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_dgemm_nt_sm50.cu) | +| **Simt** | 60+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_hgemm_nt_sm50.cu) | +| **Simt** | 61+ | 11.4+ | `s8 * s8 + s32 => {s32,s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_igemm_nt_sm50.cu) | +| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu) | +| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s4t_wmma_tensor_op_s32_sm75.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_b1t_wmma_tensor_op_s32_sm75.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f32_sm70.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_bf16n_bf16t_bf16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s8, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s4, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **TensorOp** | 90+ | 11.8+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) | ## Device-level Implicit GEMM convolution @@ -68,19 +84,19 @@ One can find and/or create equivalent dgrad and wgrad convolutional operators. |**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | |-----------------|------------------------|------------------|--------------------------------|------------------|------------------| -| **Simt** | 50,60,61,70,75 | 9.2+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu) | -| **Simt** | 50,60,61,70,75 | 9.2+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu) | -| **TensorOp** | 70 | 10.1+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu) | -| **TensorOp** | 75 | 10.2+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75 | 10.2+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu) | -| **Simt** | 80 | 11.0+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu) | -| **Simt** | 80 | 11.0+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `f16 * f16 + f16 => f16` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `tf32 * tf32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80 | 11.0+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu) | +| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu) | +| **Simt** | 50+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu) | +| **Simt** | 80+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu) | +| **Simt** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu) | diff --git a/media/docs/gemm_api_3x.md b/media/docs/gemm_api_3x.md new file mode 100644 index 0000000000..c4a454896e --- /dev/null +++ b/media/docs/gemm_api_3x.md @@ -0,0 +1,701 @@ +![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS GEMM API") + +[README](/README.md#documentation) > **CUTLASS 3.0 GEMM API** + +# CUTLASS 3.0 GEMM API + +CUTLASS presents a uniform programming model +for matrix multiply-accumulate (MMA) operations +at different levels of the GPU system hierarchy. +CUTLASS 3.0 has GEMM APIs corresponding to the following levels +in order of highest to the lowest level. + +1. Device +2. Kernel +3. Collective +4. Tiled MMA and Copy +5. Atom + +This document will cover the first three levels in detail: +Device, Kernel, and Collective. +It also briefly discusses the Tiled MMA/Copy and Atom level, +and then refers readers to CuTe's tutorial for more information. + +# CUTLASS GEMM Model + +CUTLASS implements algorithms that express +the classical "triply nested loop" GEMM algorithm +with a tiled structure mirroring the above hierarchy. + +The following pseudocode describes the model for a GEMM kernel +targeting a warp-synchronous matrix multiply instruction like `mma.sync.` +The entire operation is referred to as "Gemm," +as it is assumed that an epilogue operation +performs the general matrix update similar to BLAS. +This is pseudocode and is only meant to illustrate which parts of the layers +correspond to the inner or outer loops of the GEMM. + +```c++ +// cutlass::gemm::kernel::GemmUniversal: ClusterTileM and ClusterTileN loops +// are either rasterized by the hardware or scheduled by the kernel in persistent kernels. +// Parallelism over thread block clusters +for (int cluster_m = 0; cluster_m < GemmM; cluster_m += ClusterTileM) { + for (int cluster_n = 0; cluster_n < GemmN; cluster_n += ClusterTileN) { + + // cutlass::gemm::collective::CollectiveMma: mainloop that iterates over all k-tiles + // No loop unrolling is performed at this stage + for (int k_tile = 0; k_tile < size<2>(gmem_tensor_A); k_tile++) { + + // loops inside cute::gemm(tiled_mma, a, b, c); Dispatch 5: (V,M,K) x (V,N,K) => (V,M,N) + // TiledMma uses the hardware instruction provided through its Mma_Atom + // TiledMma's atom layout, value layout, and permutations define the iteration order + for (int tiled_mma_k = 0; tiled_mma_k < size<2>(A); tiled_mma_k++) { + for (int tiled_mma_m = 0; tiled_mma_m < size<1>(A); tiled_mma_m++) { + for (int tiled_mma_n = 0; tiled_mma_n < size<1>(B); tiled_mma_n++) { + + // TiledMma's vector mode dispatches to the underlying instruction. + mma.call(d, a, b, c); + } // tiled_mma_n + } // tiled_mma_m + } // tiled_mma_k + } // k_tile mainloop + } // cluster_m +} // cluster_n +``` + +The first three nested `for` loops +correspond to parallelism over thread block clusters. +The code does not actually express them as explicit `for` loops. +Instead, the parallelization scheme over tiles +is implied by CUDA grid launch semantics. +However, for persistent kernels, +these three loops are expressed in the source code +as a single `while` loop that queries the +[work tile scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp) +for problem tiles on which to compute. + +Inside the three nested `for` loops, +one finds code that pulls matrix tiles +from global memory into more "local" memory +(like shared memory or registers) +and computes MMAs. +These tiled copy and tiled mma iterations are generally +fully static and get fully unrolled. + +# CUTLASS GEMM Components + +CUTLASS expresses the above loop nest +with the following components which are specialized for +data type, layout, and math instruction. + +| API level | API Class and/or function names | +| --- | --- | +| Device | `cutlass::gemm::device::GemmUniversalAdapter` | +| Kernel | `cutlass::gemm::kernel::GemmUniversal` | +| Collective | `cutlass::gemm::collective::CollectiveMma`
`cutlass::epilogue::collective::DefaultEpilogue`
`cutlass::epilogue::collective::Epilogue`
| +| Tiled (MMA and Copy) | `cute::TiledMma` and `cute::TiledCopy`
`cute::gemm()` and `cute::copy()` | +| Atom | `cute::Mma_Atom` and `cute::Copy_Atom` | + +In CUTLASS 3.0, we assemble kernels +by first composing a collective mainloop and collective epilogue +together at the kernel layer, +and then wrapping them with a host-side adapter +to form a GEMM handle to that kernel. + +The following sections describe these components +in the order a user should instantiate them +in order to assemble a kernel. This order is + +1. assemble the required collective mainloop and epilogues, + +2. compose them together to build a kernel type, and + +3. wrap up the kernel with a device layer adapter. + +This order is also reflected in the [CUTLASS 3.0 Hopper kernel examples](/examples/48_hopper_warp_specialized_gemm) as seen in the excerpt below. + +```c++ +// Step 1: Generate the required collective layer mainloop specialization +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TilesShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +// Step 2: Specify the collective layer epilogue type +using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + +// Step 3: Compose the mainloop and epilogue together at the kernel layer +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, // ProblemShape [M,N,K,L] + CollectiveMainloop, + CollectiveEpilogue +>; + +// Step 4: Wrap up the kernel::GemmUniversal kernel class +// with the device adapter to obtain a host-side handle to the kernel +using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter; +``` + +Towards the end, we also briefly cover CuTe's tiled mma and copy as well as the atom layer APIs, +before redirecting users to CuTe-specific documentation for further details. + +## Collective API + +A Collective is "the largest collection of threads +onto which mma atoms and copy atoms are tiled." +That is, it is the largest number of threads in a grid +that can cooperate by leveraging hardware features +for accelerated communication and synchronization. +These hardware features include + +* asynchronous array copy + (e.g., from global memory to shared memory); + +* MMA instructions + for small tiles that live in shared memory; + +* synchronization operations for clusters, + thread blocks, and/or warps; and/or + +* hardware acceleration (such as barriers) + for ensuring that data dependencies + between asynchronous operations are met. + +A Collective uses the `TiledMma` and `TiledCopy` API (see below) +to access operations that copy and perform MMA on tiles. + +Different units of parallelism +(e.g., threads, warps, or thread blocks) +in a Collective might have different roles. +For example, in "warp-specialized" algorithms, +some warps may be responsible for copying data, +while others may be responsible for computation. +Nevertheless, the different units of parallelism +still need to share data and coordinate access +to the shared data. For example, +the producer warps in a warp-specialized algorithm +that copy input matrix tiles into shared memory +need to let the consumer MMA warp(s) know +that their MMA inputs are ready. +We contrast this with the `kernel::` layer API, +which schedules the collectives over *independent* tiles in the grid. + +The Collective API includes both the "mainloop" +of matrix multiply-accumulate, and the epilogue. +This API is the composition point for optimizations +such as mainloop fusions and epilogue fusions. +It is responsible for implementing +the `k_tile` loop in the above triply nested loop pseudocode. + +### Collective Mainloops + +The `cutlass::gemm::collective::CollectiveMma` class +is the primary interface to the collective +matrix multiply-accumulate (MMA) mainloops. +"Mainloop" refers to the "main loop" over tiles -- +the "cluster tile k" loop in the pseudocode +near the top of this document. +Any looping over multiple tiles that +the algorithm might need to do would happen here. + +The `CollectiveMma` class is declared in the header +[cutlass/gemm/collective/collective_mma.hpp](/include/cutlass/gemm/collective/collective_mma.hpp). + +```c++ +namespace cutlass::gemm::collective { + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomA, + class SmemCopyAtomA, + class TransformA, + class GmemTiledCopyB, + class SmemLayoutAtomB, + class SmemCopyAtomB, + class TransformB +> +struct CollectiveMma { + static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization."); +}; + +} // namespace cutlass::gemm::collective +``` + +- `DispatchPolicy` is the most important type for a collective, and is +[covered in more detail below](#collective-dispatch-policies). + +- `StrideA` and `StrideB` are instances of type `cute::Stride` that represent the global memory layout of A and B tensors. These strides are required to be rank-3, representing the modes `[outer, inner, batch]`. Each of the 3 ranks can be a multi-modal hierarchical stride; this would apply if implementing a tensor contraction. + +- `TiledMma` is an instance of `cute::TiledMma`. + +- `GmemTiledCopyA` and `GmemTiledCopyB` are instances of `cute::TiledCopy` types. Both tiled operation types are [covered in more detail below](#tiled-mma-and-copy). + +- `SmemLayoutAtomA` and `SmemLayoutAtomB` are instances of type `cute::Layout` and represent the smallest +layout that will get tiled over the entire collective's shared memory. This layout does _not_ include the +pipeline mode, and therefore, both are expected to be rank 2 layouts of shape [`outer`, `inner`]. + +- `SmemCopyAtomA` and `SmemCopyAtomB` are `Copy_Atom`s to be used for moving data from shared memory +into register memory. + +Notice that CUTLASS 3.0 mainloops do not accept a dedicated accumulator element type. +We obtain the accumulator type from the `typename TiledMma::ValTypeC`. Note also that +top level API's `ElementA` and `ElementB` can defer from those of the MMA facing +`typename TiledMma::ValTypeA` and `typename TiledMma::ValTypeB`, allowing TMA or user +supplied transform operations to perform type conversions. + +### Collective Dispatch Policies + +`CollectiveMma` implementations are not generic. +Instead, they must be specialized for each algorithm and GPU architecture. +Users can dispatch to a `CollectiveMma` specialization +by picking template arguments matching that specialization. +CUTLASS 3.0 adopts a tag-based dispatch policy type to specialize +mainloop implementations and add tuning knobs to them. + +Below is an example of one of the dispatch policies that is used to dispatch to a Hopper TMA +warp-specialized mainloop implementation: + +```c++ +// n-buffer in smem (Hopper TMA), +// pipelined with Hopper GMMA and TMA, +// warp-specialized dynamic schedule +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecialized +> +struct MainloopSm90TmaGmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; +``` + +The `Stages_` template parameter lets the user freely vary the number of pipeline stages, +while the `ClusterShape_` type allows for parameterization over the shape of the threadblock +cluster over which TMA multicast will take place. + +The collective dispatch policy is also the primary point of composing various kernel schedules +freely with any mainloop. Each mainloop policy either prescribes a `Schedule` with which +it needs to be run, or exposes a template API that lets the user pick a subset of the following schedules: + +```c++ +struct KernelMultistage { }; +struct KernelTma { }; +struct KernelTmaWarpSpecialized { }; +struct KernelTmaWarpSpecializedPersistent { }; +``` + +- A single kernel schedule can support multiple mainloop implementations. For example, +`KernelMultistage` can be composed with many different mainloop implementations across GPU +architectures such as `MainloopSm70TwoStage`, `MainloopSm80CpAsyncUnpredicated`, `MainloopSm90CpAsyncGmma`, and many more. + +- A single mainloop can be composed with multiple +possible kernel schedules. For example, the `MainloopSm90TmaGmmaWarpSpecialized` can be +composed with either the `KernelTmaWarpSpecialized` or `KernelTmaWarpSpecializedPersistent` +kernel schedules. + +As [discussed in the CUTLASS 3.0 design documentation](cutlass_3x_design.md), adopting tag +dispatch policies for our core vocabulary types allows us to maintain a single type name for +all operations that conceptually belong to the same class. This design has the following benefits. + +- It *avoids code duplication* in cases where mainloops can be composed with multiple kernels or vice versa. +- It *makes writing generic code easier*, as the primary type name `CollectiveMma` does not change across any implementation. +- It *provides a clear, singular extension point* for users to plug in new, custom mainloops implementations specialized on their own dispatch policies. + +### Collective Builder for `CollectiveMma`s + +The primary `CollectiveMma` is intended to be an expert user interface that allows full control over +all the properties of the collective's GPU micro-kernel. However, often a user just wants an +off-the-shelf GEMM mainloop implementation parameterized on simple configuration parameters. CUTLASS 3.0 +provides [`cutlass::gemm::collective::CollectiveBuilder`](include/cutlass/gemm/collective/collective_builder.hpp) for such scenarios. + +```c++ +namespace cutlass::gemm::collective { +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; +} // namespace cutlass::gemm::collective +``` + +`CollectiveBuilder` accepts CUTLASS 2.x equivalent input template arguments, and attempts to build +the best performing `CollectiveMma` from the given parameters. + +- `ArchTag` is one of the SM architectures tags from `cutlass::arch::Sm*`. +- `OpClass` is one of the operator class tags from `cutlass::arch::Sm*`. +- `ElementA` and `ElementB` are the logical value types of the A resp. B tensors. +- `ElementAccumulator` is the accumulator type to be used in the instruction. +- `GmemLayoutA` and `GmemLayoutB` are CUTLASS 2.x layout tags, `layout::RowMajor` or `layout::ColumnMajor`. +- `AlignmentA` and `AlignmentB` are global memory alignments of A and B tensors in terms of element count. +- `TileShape_MNK` is an instance of `cute::Shape` that is rank-3, representing the MxNxK collective tile shape. +- `ClusterShape_MNK` is an instance of `cute::Shape` that is rank-3, representing the MxNxK threadblock cluster tile shape. +- `StageCountType` is either `collective::StageCountAuto` or an instance of `collective::StageCount`. +- `KernelScheduleType` is either `collective::KernelScheduleAuto` or one of the specific kernel schedule tags discussed in the [dispatch policy section](#collective-dispatch-policies) above. + +`StageCountAuto` allows the collective builder to compute the size of a single stage's size in shared memory +and maximize the shared memory usage assuming 1 threadblock / multiprocessor occupancy. + +`KernelScheduleAuto` allows the collective builder to pick the best kernel schedule available for the +given set of parameters, or let's the user override this with a specific kernel schedule type. + +Note that collective builders are still in beta, and their functionality +does not map onto the full design space that the primary expert `CollectiveMma` API +allows for. We expect their supported mainloop types to expand in future releases, but +with 3.0, only SM90 tensorop kernels are supported through the builder API. The builder API +may also change in the future as we adopt user feedback. + +If the builder is able to provide a collective mainloop type for the given set of parameters, +it will be aliased within as `CollectiveOp`. For more information on how to +parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_schedules_with_collective_builder](49_hopper_gemm_schedules_with_collective_builder). + +### Epilogue + +The collective epilogue implements element-wise operations +involving the output matrix. Users can provide a custom +epilogue, or use one of the standard epilogues. +These live in the directory +[include/cutlass/epilogue/collective/](../../include/cutlass/epilogue/collective/), +and include classes like +`cutlass::epilogue::collective::DefaultEpilogue` +and +`cutlass::epilogue::collective::Epilogue`. +CUTLASS's provided collective epilogues +do not live under `include/cutlass/gemm` +or in the `cutlass::gemm` namespace, +because they can be used for computations +other than GEMM. + +## Kernel API + +The kernel is "a collection of all clusters in the grid." +The kernel layer schedules have four main responsibilities. + +- Ordering the execution of collectives within the kernel, performing any synchronization between that may be necessary +- Marshalling the threads of a warp specialized schedules into their respective roles +- Performing any necessary grid swizzling logic +- Tiling the input tensors with the threadblock cluster value tile before invoking the collectives on them + +The Kernel API is the entry point for a grid of thread blocks +that may or may not be organized in a cluster. +It is the composition point for fusing back-to-back GEMMs, +epilogues, and/or other operations. + +The entry point API for CUTLASS 3.0 kernel is the class +`cutlass::gemm::kernel::GemmUniversal`, found in the header file +[include/cutlass/gemm/kernel/gemm_universal.hpp](../../include/cutlass/gemm/kernel/gemm_universal.hpp). +`GemmUniversal` is a stateless universal device kernel +that implements GEMM as the composition of two parts: + +* a collective mainloop, and +* a collective epilogue + +```cpp +namespace cutlass::gemm::kernel { +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. +**/ +template < + class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) + class CollectiveMainloopOrEpilogue_, + class CollectiveEpilogueOrThreadblockSwizzle_, + class GridSwizzle_ = void, + class Enable = void +> +class GemmUniversal; +} // namespace cutlass::gemm::kernel +``` + +*Stateless* means that the caller -- +for example, the Device API described above -- +manages the kernel's state. +The kernel just takes input and output parameters (`Params`). + +*Universal* means that `GemmUniversal` works +for both CUTLASS 3.0 and 2.x interfaces +and across a broad range of kernel schedules. +If `GemmUniversal`'s first template argument is a `cute::Shape`, +then `GemmUniversal` assumes that the remaining template arguments +implement the 3.0 APIs. Otherwise, `GemmUniversal` assumes that +the remaining template arguments implement the 2.x APIs. +Starting with CUTLASS 3.0, the problem shape has been promoted +to a top-level template API for the GEMM kernel. +This supports fully static GEMM instantiations +where the user expects to know some or all +of the problem shapes at compile time +in order to extract even more performance. + +The *collective mainloop* implements MMA on local tiles. +The *collective epilogue* addresses any operations after the MMA, +such as applying the `beta * C` part of `C := beta * C + alpha * A * B`. +We will explain *collective* in more detail below. + +Specializations of `kernel::GemmUniversal` for 3.0 APIs live in +any of various `gemm_*.hpp` files in the directory +[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +Specializations for 2.x APIs can be found in the header file +[include/cutlass/gemm/kernel/gemm_universal.h](../../include/cutlass/gemm/kernel/gemm_universal.h). + +CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`. +Each kernel layer schedule is specialized +for a GEMM scheduling algorithm and GPU architecture. +Specializations of `kernel::GemmUniversal` for 3.0 APIs live in +any of various `include/cutlass/gemm/kernel/{arch_tag}*.hpp` files in the directory +[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +Which specialization to dispatch to is decided through the dispatch policy's `Schedule` type. + +For example, the header file +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp) +has a specialization of `kernel::GemmUniversal` for Hopper +that uses a warp-specialized mainloop with a persistent scheduling algorithm, +while the header file +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) +has a specialization of `GemmUniversal` for Hopper +that uses a warp-specialized but non-persistent algorithm. + +To support composition between supported kernel schedules and mainloop dispatch policies without having to +duplicate collective mainloop implementations, GEMM kernel layer schedules can be composed with +any mainloop that specifies their corresponding kernel schedule as their `Schedule` type in the policy. +This is discussed in detail in the [collective dispatch policy section](#collective-dispatch-policies) above. + +```c++ +// An example of the SM90 KernelMultistage kernel's +// specialization logic that allows it to be composed +// with many mainloops such as `MainloopSm80CpAsync` +// and `MainloopSm70TwoStage`. +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + std::enable_if_t>> +``` + +## Device API + +The Device API is a universal, kernel-agnostic host interface +for kernel launch and managing the lifetime of +reusable host-side parameters. + +This API is how users' host-side .cu code +invokes CUTLASS's single-GPU GEMM kernels. +It serves the same purpose as cuBLAS and behaves similarly. + +The entry point for the Device GEMM API is the class +`cutlass::gemm::device::GemmUniversalAdapter`. +This class lives in the header file +[include/cutlass/gemm/device/gemm_universal_adapter.h](/include/cutlass/gemm/device/gemm_universal_adapter.h). +`GemmUniversalAdapter` is a stateful, reusable handle, +which is parameterized on the `cutlass::gemm::kernel` type. + +```c++ +/*! + GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::* + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, new static methods + are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of GemmUniversalAdapter + on the two kernel API types, and thus, GemmUniversalAdapter's behavior might + differ between the two specializations. +*/ +template +class GemmUniversalAdapter; +``` + +*Stateful* means that the handle instance contains state +that the kernel needs to run. +This means that the user must initialize the handle first, +then use the initialized handle instance to run the kernel. +Statefulness also means that the handle can manage the lifetime +of the kernel's `Params` -- the parameters of the kernel itself. +An important duty of `GemmUniversalAdapter` +is to map from the user's `Arguments` -- +what the user sees as the kernel's parameters -- +to the `Params` that the kernel actually sees. +For power users, the class exposes new static methods +in 3.0 APIs that can bypass stateful methods +or go directly to `Params` without intermediate `Arguments`. + +*Reusable* means that the handle instance can be used +to call the kernel multiple times with different arguments +(e.g., different matrices). +Reusing the handle may be more efficient than just +creating a new handle for each kernel invocation. + +*Parameterized on the kernel type* means that +the `GemmUniversalAdapter` class' behavior +depends on the GEMM kernel type (see the next section). +Specifically, `GemmUniversalAdapter` has a template parameter +`GemmKernel`, which is the GEMM kernel type. +Valid template arguments for `GemmKernel` are + +* `cutlass::gemm::kernel::GemmUniversal`, + implementing CUTLASS 3.x API kernels; +* `cutlass::gemm::kernel::GemmUniversal`, + implementing CUTLASS 2.x API kernels; or +* Any valid CUTLASS 2.x `kernel` layer GEMM that + was previously composable with the `device::GemmUniversalAdapter`. + +`GemmUniversalAdapter` presents a single +host-side interface to both 3.0 and 2.x kernels. +CUTLASS accomplishes this by +specializing `GemmUniversalAdapter`'s implementation +on either the 2.x API implementing kernel layer GEMMs, or on the 3.x API +implementing kernel layer GEMMs. The metafunction [`cutlass::gemm::detail::IsCutlass3GemmKernel`](cutlass_3x_backwards_compatibility.md#kernel-api-design-differences) +is what `GemmUniversalAdapter` uses to distinguish between 2.x and 3.x kernels. + +`GemmUniversalAdapter` sets up and launches the kernel, using the +CUDA extended launch API for threadblock cluster support if required. +Note, `GemmUniversalAdapter` does *not* specify the grid shape. +The kernel controls the grid shape +and other kernel-specific launch parameters. +This makes it possible for all 3.0 kernels +to use the same kernel launch code, +thus factoring out kernel launch from the actual kernel. + +## Tiled MMA and Copy + +The Tiled MMA or Copy are tilings of MMA atoms resp. Copy atoms +across threads and data, with possible permutations applied to the +resulting tiling. This layer is most analogous to the warp level +tiling of MMA instructions in CUTLASS 2.x. However, it views the tiling +from the perspective of all threads participating in the operation +and generalizes the concept to copy operations as well. The purpose +of this layer is to build composable GPU micro-kernels out of a plethora +of hardware accelerated math and data movement operations, each with their +unit layouts in threads and data. The tiled MMA and Copy types present +all these various hardware accelerated CuTe Atoms with a single, consistent +API. + +The resulting tiled operation acts as a single MMA or copy operation +that users can invoke in the "inner" loop +of the three-nested-loops pseudocode +at the top of this document using `cute::gemm()` or `cute::copy()`. + +We call this API "tiled" because it constructs +larger operations out of the Atoms provided by CuTe, +as if fitting together individual tiles +to build a reusable component of a mosaic. +For example, CuTe might provide an MMA Atom +that users can call on a single warp, +for fixed M, N, and K dimensions. +CUTLASS can then use CuTe operations like `make_tiled_mma` +to turn this Atom into an operation +that works on an entire thread block, +for larger M, N, and K dimensions. + +## Atom API + +An "Atom" is the smallest collection of threads and data +that must participate in the execution of a hardware-accelerated +math or copy operation. + +An Atom is "atomic" (indivisible) not in the sense of +concurrent memory operations like `atomicAdd` +(which are "indivisible in time (causality)"), +but in the sense of indivisibility in "space" -- +the number of values and the groups of parallel workers +that must participate in the operation together. + +An Atom uses CuTe Layouts to express the required +dimensions and strides of its input and output arrays. +Generally these are fixed at compile time. + +The Atom API wraps calls to actual hardware instructions +that accelerate MMA or copy operations. +Users can ask for GPU architecture-specific implementations, +or just pick generic implementations and rely on +whatever GPU architectures were enabled. + +For more information about Atoms, +please refer to CuTe's tutorial, e.g., the sections on + +* [algorithms](./cute/04_algorithms.md) like `gemm` and `copy`, + +* [MMA Atoms](./cute/0t_mma_atom.md#cute-mma-atoms), and + +* [a GEMM example](./cute/0x_gemm_tutorial.md). + +# Copyright + +Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/media/docs/layout.md b/media/docs/layout.md index f8e21da048..eb68abcdb3 100644 --- a/media/docs/layout.md +++ b/media/docs/layout.md @@ -2,6 +2,11 @@ [README](/README.md#documentation) > **Layouts and Tensors** +Note: This document talks about CUTLASS 2.x layout tag types. +CUTLASS 3.0 deprecates all legacy 2.x layout tags in favour of a single `cute::Layout` +vocabulary type for all thread and data tensors. Please refer to the +[documentation for cute layouts](media/docs/cute/01_layout.md) for more details about CUTLASS 3.0's definition of "layout". + # Layouts and Tensors _Tensors_ are mathematical objects represented by a multidimensional array of numeric elements in memory. diff --git a/media/docs/pipeline.md b/media/docs/pipeline.md new file mode 100644 index 0000000000..ccf8385953 --- /dev/null +++ b/media/docs/pipeline.md @@ -0,0 +1,210 @@ +# Synchronization primitives + +## Overview of CUDA's synchronization methods + +The CUDA programming model provides 3 abstractions: + +* hierarchical parallelism -- that is, parallel threads + grouped into hierarchical units such as blocks and clusters; + +* shared memory, through which parallel threads that are + in the same hierarchical unit can communicate; and + +* synchronization methods for threads. + +These abstractions help developers extract +both fine-grained and coarse-grained parallelism, +by making it possible for them to subdivide problems +into independent components, +and to insert synchronization at appropriate points. + +Over the years CUDA has introduced several synchronization primitives +that operate at different levels of the hierarchy. +These include + +* [thread block - level](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#synchronization-functions) synchronization (e.g., `__syncthreads()`); + +* [warp-level](https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/) synchronization (e.g., `__syncwarp()`); and + +* [thread-level](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#memory-fence-functions) fence operations. + +As an extension to this, starting with the Hopper architecture, CUDA added the following improvements: + +* [thread block clusters](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#thread-block-clusters) -- + a new level in the thread hierarchy representing + a group of thread blocks that can coordinate and share data; + +* synchronization instructions for a thread block cluster and threads within a cluster scope. + +## CUTLASS's abstractions for Hopper features + +CUTLASS now includes abstractions +for the following features introduced in Hopper. + +1. Thread block cluster - level synchronization and query + [APIs](/include/cute/arch/cluster_sm90.hpp) + +2. Abstractions for new + [barrier instructions](/include/cutlass/arch/barrier.h) + which help with efficient synchronization + of threads within a thread block cluster. + +### Asynchronous pipelines + +In order to write a performant GEMM Kernel, +software pipelining is critical to hide the latency of global memory loads. +(Please refer to the +[Efficient GEMM](/media/docs/efficient_gemm.md#pipelining) document.) +Different threads or groups of threads +may have different roles in the pipeline. +Some are "producers" that load data or perform computations +to satisfy other threads' input data dependencies. +The same or different threads may be "consumers" +that do other work with those input data dependencies, +once they are satisfied. +Starting with the Hopper architecture, +the presence of hardware-accelerated synchronization instructions +make it possible for "producer" and "consumer" threads +to communicate with each other efficiently +about their data dependencies. + +Implementing a persistent GEMM algorithm calls for managing +dozens of different kinds of asynchronously executing operations +that synchronize using multiple barriers organized as a circular list. +This complexity is too much for human programmers to manage by hand. +As a result, we have developed +[asynchronous Pipeline classes](/include/cutlass/pipeline.hpp). +These classes help developers orchestrate a pipeline +of asynchronous producer and consumer threads, +without needing to worry about lower-level hardware details. +These classes serve a similar function as the various +[pipeline abstractions](https://nvidia.github.io/libcudacxx/extended_api/synchronization_primitives/pipeline.html) +in libcu++. + +#### Pipeline methods + +##### Producer acquire + +The `producer_acquire` method is to be used by asynchronous producer threads +before issuing other instructions associated with a particular pipeline stage +(e.g., copy or write). + +This is a blocking instruction +which blocks further execution of consumer threads +unless the particular stage waiting to be acquired +is released by a consumer. + +We say that a pipeline at its start is "empty" if producer threads are free to produce and do not need to wait for a consumer release -- that is, if an acquire operation is expected to succeed. If the pipeline at its start is empty, then we can either skip performing producer acquire operations during the first pass through the pipeline stages, or use the `make_producer_start_state` method. The latter ensures that the acquire operation will succeed at the start of a pipeline. + +##### Producer commit + +The `producer_commit` method is to be issued by asynchronous producer threads +after the instructions associated with a particular stage +(e.g., shared memory writes) have completed, +in order to notify the waiting asynchronous consumer threads. +This is a nonblocking instruction. + +This API may result in a No-Op in some cases, +if the producer instructions also update the barrier stage associated automatically +(e.g., TMA_based producer threads using the `PipelineTmaAsync ` class). + +##### Consumer wait + +The `consumer_wait` method is to be used by consumer threads +before consuming data from a particular pipeline stage +which is expected to be produced by producer threads. + +This is a blocking instruction. That is, +until the producer threads have committed to a particular stage, +this instruction is expected to block further execution of consumer threads. + +##### Consumer release + +The `consumer_release` method is to be used by consumer threads +to signal waiting producer threads that they have finished consuming data +associated with a particular stage of the pipeline. +This is a nonblocking instruction. + +#### Pipeline example + +```c++ +// 4-stage Pipeline +static constexpr int NumStages = 4; +using MainloopPipeline = typename cutlass::PipelineAsync; +using PipelineState = typename cutlass::PipelineState; + +// 2 producer threads and 1 consumer thread +typename MainloopPipeline::Params params; +params.producer_arv_count = 2; +params.consumer_arv_count = 1; +MainloopPipeline pipeline(shared_storage.storage, params); + +// Producer threads +if (thread_idx == 0 or thread_idx == 1) { + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + for ( ; iter > 0; --iter) { + pipeline.producer_acquire(smem_pipe_write); + + // Producer ops + // If any memory operations are involved, then we also need + // to guarantee that writes are completed and visible to consumer(s). + + pipeline.producer_commit(smem_pipe_write.index()); + ++smem_pipe_write; + } +} +else if (thread_idx == 2) { + PipelineState smem_pipe_read; + for (; iter > 0; --iter) { + pipeline.consumer_wait(smem_pipe_read); + + // Consumer ops + + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } +} +``` + +In this example, we create an instance of the asynchronous pipeline class `PipelineSync`, +and then synchronize among 3 asynchronously executing threads: +2 producer threads and 1 consumer thread. + +Please note that this is a basic example. +There are different versions possible, +depending on what the producer and consumer threads are doing. +Please refer to our [unit tests](/test/unit/pipeline) +and the other [pipeline classes](/include/cutlass/pipeline.hpp) +for more details. + +# Copyright + +Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/media/docs/profiler.md b/media/docs/profiler.md index a841a80757..b8e409fc05 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -13,7 +13,7 @@ The CUTLASS Profiler may be compiled with: $ make cutlass_profiler -j ``` -To limit compilation time, only one tile size (typically 128x128) is instantiated for each data type, +To limit compilation time, only one tile size (typically 128x128) and threadblock cluster size (typically 2x1x1) is instantiated for each data type, math instruction, and layout. To instantiate all sizes, set the following environment variable when running CMake from an empty `build/` directory. ```bash @@ -168,8 +168,8 @@ Example: The CUTLASS Profiler is capable of executing GEMM and Sparse GEMM problems. The CUTLASS Profiler can be built with cuBLAS enabled to use as a reference implementation. If CMake detects -the cuBLASS library available in the system, it is included as a dependency. This may be explicitly overridden -with CMake flag `CUTLASS_ENABLE_CUBLAS`. +the cuBLAS library available in the system, it is included as a dependency. This may be explicitly overridden +with CMake flag `CUTLASS_ENABLE_CUBLAS`. ## GEMM Arguments @@ -197,6 +197,9 @@ GEMM [int] --cta_m,--threadblock-shape::m Threadblock shape in the M dimension. [int] --cta_n,--threadblock-shape::n Threadblock shape in the N dimension. [int] --cta_k,--threadblock-shape::k Threadblock shape in the K dimension. + [int] --cluster_m,--cluster-shape-shape::m Cluster shape in the M dimension. + [int] --cluster_n,--cluster-shape-shape::n Cluster shape in the N dimension. + [int] --cluster_k,--cluster-shape-shape::k Cluster shape in the K dimension. [int] --stages,--threadblock-stages Number of stages of threadblock-scoped matrix multiply. [int] --warps_m,--warp-count::m Number of warps within threadblock along the M dimension. [int] --warps_n,--warp-count::n Number of warps within threadblock along the N dimension. @@ -342,7 +345,50 @@ To faclitate generation of pivot tables and charts, additional columns may be pr $ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sgemm_128x128_nn \ --m=3456 --n=4096 --k=8:4096:8 --output=report.csv \ --tags=cutlass:2.2,date:2020-06-08 -``` +``` + +## CUTLASS 3.0 GEMM procedural names + +CUTLASS 3.0 introduces a new naming convention for GEMMs used by the profiler targeting the NVIDIA +Hopper architecture and beyond so as to indicate new features of the kernel within the name +(e.g., the cluster shape). + +To best illustrate this naming convention, we will walk through the meaning of each of the components +in a GEMM kernel used by the profiler: +``` +cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f32_128x128x64_2x1x1_0_ntn_align8 +``` + +The components within this name are as follows: + +* `cutlass3x`: indicates that the kernel was generated through the CUTLASS 3.0 API +* `sm90`: indicates that the kernel targets NVIDIA GPUs with compute capability 90 +* `tensorop`: indicates that the kernel makes use of NVIDIA Tensor Cores +(as opposed to `simt`, which indicates the use of "CUDA cores") +* `s`: indicates that the Tensor Core instruction being used accumulates in single precision +(as opposed to `h`, which indicates half precision) +* `64x128x16gemm`: indicates that the shape of the Tensor Core instruction being used (MxNxK) is 64x128x16 +* `f16_f16_f32_f16`: indicates that the data types for operands A, B, and C are each `f16` +(half precision) and that accumulation is performed using `f32` (single precision) +* `128x128x64`: indicates that the thread block shape used in the GEMM (MxNxK) is 128x128x64 +* `2x1x1`: indicates that the cluster shape being used is 2x1x1 +* `0`: indicates that the kernel uses the CollectiveBuilder's automatic stage calculation to determine the +number of pipeline stages in the kernel. Note that `0` does not mean that no stages are used. A nonzero value indicates that automatic stage calculation is not performed and indicates the number of pipeline stages to be used. +This 0 is only added to the kernel's procedural name, the profiler will still report the actual stage count +when printing the kernel argument details (`--stages=N`) and kernel discovery will still support filtering through the `--stages` argument. +* `ntn`: indicates that the layouts for operands A, B, and C are column major ("n"; non-transposed), +row major ("t"; transposed), and column major, respectively. +* `align8`: indicates that the maximum alignment between operands A and B is 8. + +Note that in some special cases where the input A/B types do not match that of the MMA +instruction's, the MMA facing input type is added to the instruction string as well. + +``` +cutlass3x_sm90_tensorop_s64x128x8tf32gemm_f32_f32_f32_f32_128x128x32_2x1x1_0_tnn_align4 +``` + +* `s64x128x8tf32gemm`: indicates that the MMA consumes inputs in `tf32` format, and therefore +the kernel performs rounding of the `f32` values in global memory while loading them into shared memory. # Convolution diff --git a/media/docs/programming_guidelines.md b/media/docs/programming_guidelines.md index 59c2f57f0c..8e454fa42f 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/programming_guidelines.md @@ -6,32 +6,23 @@ ## Hierarchical Organization -CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) -for expressing collective operations. Objects expose an interface for a problem that is then decomposed -into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level -object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent -offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some -operations as a collection of cooperating threads, while it may partition other parts of the task into -warp-level subtasks. - -Consequently, CUTLASS components are organized by the computation then by the layer of -the following hierarchy. - -* *device*: an operation is _device-wide_ and may launch one or more kernels on the GPU -* *kernel*: an operation is implemented by a CUDA kernel with definitions for `__shared__` memory and constant memory allocations -* *threadblock*: an operation is collectivey executed by a threadblock; any component calling `__syncthreads()` is likely to be threadblock-scope -* *warp*: an operation is collectively executed by a warp; threads within the context of a warp are referred to as _lane_ -* *thread*: an operation is performed by an individual thread with no other data sharing or interaction with other threads -* *instruction*: an operation corresponds to an individual hardware or PTX instruction +The [CUTLASS 3.0 GEMM API](./gemm_api_3x.md) document +explains CUTLASS 3.0's hierarchical organization, +based conceptually on parallelization strategy. +This differs from CUTLASS 2.x's approach, +which more closely mirrors the GPU hardware hierarchy +of thread blocks, warps, and threads. ## Design Patterns -CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a -flexible composition that an be easily applied to solve new problems related to Deep Learning and -linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given -a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several -design patterns are necessary to yield a composable structure while also satisfying these performance -objectives. This section is intended to provide more detail. +CUTLASS aims for the highest performance possible on NVIDIA GPUs. +It also offers flexible components that can be assembled and customized +to solve new problems related to deep learning and linear algebra. +Given a tradeoff between simplicity and performance, +CUTLASS chooses performance. +Consequently, several design patterns are necessary +to yield a composable structure +while also satisfying these performance objectives. ### Templates @@ -75,8 +66,9 @@ objects for each data member. To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes -of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same -shared memory region and reduce overall SMEM capacity. +of child objects are known to be non-overlapping, `union`s may be used to alias multiple SharedStorage objects to the same +shared memory region and reduce overall shared memory capacity. Developers should carefully note that C++ `union` rules +require that they only access the most recently written ("active") member of the `union`; this differs from C rules. ### Loop Unrolling @@ -104,123 +96,578 @@ for (int idx = 0; idx < kN; ++idx) { // Loop has constant number of iterati ## Style -### C++ Style +### No automatic code formatting -CUTLASS source code follows the -[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) with exceptions and extensions. +Do not use any kind of automatic code formatting, +like `clang-format`, on CUTLASS code. -Design choices should be consistent with the -[CppCoreGuidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md) recommendations by Stroustrup and Sutter. +### C++ style -### CUDA Built-in Variables +#### CUTLASS is a C++ project -Avoid direct access to CUDA built-in variables `threadIdx`, `blockIdx`, `blockDim`, and `gridDim` within -CUTLASS components except in special circumstances. +CUTLASS is a C++ project. CUDA C++ is a C++ dialect. +Therefore, we write using standard C++ idioms as much as possible. +We aim for portability to as many compilers as possible, +by writing host code in Standard C++ +and device code in CUDA C++ +that resembles Standard C++ as much as possible. +This improves usability +for the general community of C++ developers, +and makes it easier for new staff to join the project. -Using built-in 'global' variables directly within resuable components necessitates that all components -use them consistently which may not be possible if CUTLASS components are used in other contexts. +#### Follow Standard C++ idioms where possible -Instead, components should accept a linear ID identifying threads, warps, and threadblocks from calling -code. The top-level kernel may then decide how to map threads, warps, and blocks to the problem it is -solving. +Regarding "standard C++ idioms," +CUTLASS source code follows the following guidelines, +with deviations only because of compiler limitations +or where performance absolutely requires it. +"Performance requires it" implies measurement. +Deviations should be limited in scope +and we should always strive to eliminate them. -### Use CUTLASS Fundamental Types +* [C++ Core Guidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md) -Use the [fundamental types](fundamental_types.md) defined in CUTLASS consistently. Doing so contributes -to a framework of interoperable, consistent components. +* [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) -In particular, be sure to use: +#### Spacing and line length -* [Numeric types](fundamental_types.md#numeric-types) to represent numeric data in host and device code -* [Containers](fundamental_types.md#containers) to store data in register-backed arrays -* [functional.h](fundamental_types.md#functional) to perform numeric operations in generic code -* [Layouts](layout.md) to store stride and partially specialize template classes -* [`TensorRef` and `TensorView`](layout.md#tensorref) to pass pointers and layout objects +* Use spaces, not tabs. -Avoid defining alternative implementations of the same functionality. Instead, prefer to enhance -or extend additional components where it makes sense. +* Use 2 spaces to indent. -### Classes and Structs +* Max 100 characters per line. -Type names use `CapitalLetters` except when implementations are a _perfect_ drop-in replacement for -Standard Library components. +Lines longer than 100 characters typically wrap unfavorably +when viewed in Github's pretty printer. -Follow the [CppCoreGuidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-struct) -to decide whether to use `class` or `struct`. Namely, -* use `class` when the object must maintain an invariant. Data members related to the invariant should be private. -* use `struct` when the class has no invariant to maintain, and data members may vary arbitrarily. +#### Function indentation -### Class Members +When calling a function or function object with a long name, +break the line right after the invoking open parenthesis. +Here is an example. -Methods and members are written using `snake_case`. +```c++ +detail::very_long_function_object_name{}( + params.long_parameter_name, some_operator.another_long_function_name()); +``` -Private data and function members have suffix `_`. +When declaring functions, indent function parameters like this. + +```c++ +void possibly_an_unusually_long_function_name( + std::uint32_t foo + std::uint32_t const* bar, + TypeA a, + TypeB b, + TypeC c) +{ + // ... the function's body ... +} +``` + +For function definitions only, +break the line between the parenthesis +that closes the function's parameters, +and the curly bracket +that opens the function's body. + +#### If-else brackets and spacing + +* Always use braces with conditionals such as `if`. + +* Use a space after control flow keywords + such as `if`, `for`, and `while`. + +* Use a space after the parenthesis closing a conditional + such as `if`, and the curly bracket opening a scope. + +* Use a new line between the closing brace + of an `if` branch, and the `else` keyword. + +```c++ +if (condition) { + // ... code ... +} +else { + // ... other code ... +} + +for (int k = 0; k < num_iters; ++k) { + // ... still more code ... +} +``` + +#### East const + +CUTLASS uses the +["East const"](http://slashslash.info/2018/02/a-foolish-consistency/) +convention. +That is, the `const` or `constexpr` keyword +goes after the type, not before. +The general rule is that `const` or `constexpr` +modifies the type to the left of it. +Here are some examples. + +```c++ +float constexpr compile_time_constant = 42.3f; + +float const const_float = /* whatever */; +float const& reference_to_const_float = const_float; +float const* pointer_to_const_float = &const_float; +float const* const const_pointer_to_const_float = &const_float; + +float nonconst_float; +float& reference_to_nonconst_float = nonconst_float; +float* pointer_to_nonconst_float = &nonconst_float; +float* const pointer_to_nonconst_float = &nonconst_float; +``` + +Contrast this with "West const" style, e.g., + +```c++ +const float const_float = /* whatever */; +const float* pointer_to_const_float = &const_float; +``` -### Constant names +#### Alignment of reference and pointer types -CUTLASS makes extensive use of constants and compile-time evaluation. Constant variable names should have -prefix `k` and use mixed case. True compile-time constsants should be defined as `constexpr` to enable -dependent `constexpr` functions. +For reference and pointer types, +align the `&` resp. `*` flush against the type +that it modifies. This is called "left alignment." -CUTLASS uses ["East const"](http://slashslash.info/2018/02/a-foolish-consistency/) style, placing `constexpr` keyword -after the type name. +For example, do this: ```c++ -float constexpr kPi = 3.14159f; +int const& var; +int const* var; ``` -### Class Member Order +and not this. + +```c++ +int const &var; +int const *var; +``` + +#### Avoid calling functions "fast" or "optimized" + +Putting words like "fast" or "optimized" +in the name of a function +assumes that the "fast" path is actually faster. +That might be true now, but later changes +(in the code, compilers, or GPU hardware) +might make it false. In that case, +your name could be unintentionally misleading. +Consider instead a name that briefly describes +the algorithm or feature that is relevant for optimization. +For example, `compute_on_host` is more meaningful +than `compute_slowly`, and computing on host +might be faster in some cases +(e.g., if the data are already on host +and the algorithm is not GPU-friendly). + +CUTLASS code has not always followed this rule in the past. +Some functions and classes might have words like "fast" in their name. +New code should follow this rule, however. + +#### Avoid creating unconstrained templated functions with common names + +See [C++ Core Guidelines T.47](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#t47-avoid-highly-visible-unconstrained-templates-with-common-names): +"Avoid highly visible unconstrained templates +with common names." +Argument-dependent lookup (ADL) means that +if users call a function name without specifying the namespace, +the compiler can find overloads +of that function in any namespace. +This can lead to ambiguous overloads in users' code, +just because they happened to include one of your header files +that exposes an unconstrained function template. +The following illustrates this +with an unconstrained swap overload in the `cutlass` namespace. + +```c++ +#include +#include +#include + +// Uncomment the line below to observe unwarranted build errors. +//#define BAD_CUTLASS_SWAP 1 + +namespace cutlass { +struct Bar { + float f; +}; +} // namespace cutlass + +#ifdef BAD_CUTLASS_SWAP +namespace cutlass { + +template +void swap(T& a, T& b) // don't do this +{ + T tmp = a; + a = b; + b = tmp; +} + +} // namespace cutlass +#endif // BAD_CUTLASS_SWAP + +namespace other { + +#ifdef BAD_CUTLASS_SWAP +using cutlass::swap; +#endif // BAD_CUTLASS_SWAP + +// Imagine for the sake of this example +// that "foo" is a less common name, +// and that T is constrained via +// std::enable_if or a requires clause. +template +void foo(T& a, T& b) +{ + // The usual idiom for using std::swap is the "swap two-step": + // + // 1. import std::swap into the current scope, then + // 2. call swap without namespace qualification. + // + // That won't build if we have another swap + // overload available in the scope already. + + using std::swap; + swap(a, b); // OBSERVE UNWARRANTED BUILD ERROR HERE +} + +} // namespace other + +int main() +{ + int x = 42; + int y = 43; + other::foo(x, y); + assert(x == 43); + assert(y == 42); + + cutlass::Bar a{42.0}; + cutlass::Bar b{43.0}; + other::foo(a, b); + assert(a.f == 43.0); + assert(b.f == 42.0); + + // GCC 7.5 std::unique_ptr::reset calls swap, + // leading to the same issue as above. + // GCC 12.2's implementation of std::unique_ptr + // does not have this issue. Nevertheless, + // breaking the swap two-step will break users' code, + // just by them happening to include your headers. + auto ptr = std::make_unique(cutlass::Bar{666.0f}); + ptr.reset(new cutlass::Bar{777.0f}); // OBSERVE UNWARRANTED BUILD ERROR HERE + + return 0; +} +``` + +#### Function return values and in-out parameters + +##### Prefer return values to output parameters + +In general, avoid in-out mutable references to return a value. +If you need to return multiple values, +you can return them by `struct` or `tuple`, +rather than by output references. +This includes the special case of error reporting +by returning either a value or an error code. +Please see the next section for details. + +```c++ +// Instead of passing in-out mutable references ... +void not_preferred(float& input_and_output); // not preferred + +// keep functions pure and return value types instead +float preferred(float input); // preferred +``` + +##### Return multiple values by struct or tuple + +Sometimes a function needs to return multiple values. In that case, consider the following, in decreasing order of preference. + +1. Return a `struct`. This lets you name the fields + (for more self-documenting code), + yet still permits use of structured binding. + +2. Return a `tuple`. If you need a tuple type + that works on device, use `cute::tuple`. + (Please note that `cute::tuple` does not work + for all the types that work in `std::tuple`. + CuTe's documentation explains.) + +Here is an example of the struct approach for named values. +For a comparable example in the C++ Standard, +please see [`std::allocate_at_least`](https://en.cppreference.com/w/cpp/memory/allocate_at_least), +which returns `std::allocation_result`. + +```c++ +struct my_computation_result { + float value = 0.0f; + float relative_error = 0.0f; + bool success = false; +}; + +my_computation_result my_computation(float tolerance); + +void foo(float tolerance) +{ + // Approach 1: Use structured binding. The names + // you choose on the left-hand side have nothing + // to do with the struct, so it's up to you + // to get the order right. On the other hand, + // this code works whether my_computation returns + // a struct or a tuple. + auto [val, rel_err, ok] = my_computation(tolerance); + + // Approach 2: Keep the struct and use its named fields. + // This approach prevents errors like mixing the order of return types. + // However, it only works for structs, not for tuples. + + auto result = my_computation(tolerance); + if (not result.success) { + // computation did not succeed + } + else if (result.relative_error > tolerance) { + // successful but relative error too large + } + else { + // successful and relative error is in bounds + } +} +``` + +##### Reporting errors from a function that returns one or more values + +We may want to return one or more values +from a function that could fail +or otherwise report errors. +That is, the function either + +* returns one or more valid values, or + +* does not return any values and reports an error, + +but NOT BOTH. We contrast this with cases +when it's meaningful to report both a result +and whether the result is satisfactory. +For example, when solving +a system of nonlinear equations iteratively, +users may want the approximate computed solution, +even if the iteration did not succeed +by converging to the desired tolerance +in the desired number of steps. +(Users may want to invest more steps, +or use the current approximation +to jump-start a different algorithm.) + +We're talking here about the "either valid value(s), +or error, but not both" case. +For this case, C++ offers a few options. + +1. Return the value(s), or throw an exception on error + +2. `std::expected` (requiring C++23) or something like it + +3. `std::optional` (for a Boolean error state) + or something like it + +4. `std::variant` (a C++17 fall-back for `std::expected`) + or something like it + +5. C-style interface: return an error code, + and "return" the values as output parameters + +We usually cannot or do not want to +throw exceptions on device. +Some code projects forbid exceptions entirely +(on host or device) +and tell the compiler to disable them. +If we exclude a C-style interface (the last option) +as not idiomatic C++, then for host-only code, +`std::expected`, `std::optional`, and `std::variant` +all work. +For code that needs to build and run on device, +we can fall back to libcu++ equivalents +in the `cuda::std::` namespace, when they exist. +Otherwise, we must resort to returning a struct or tuple +with the value and the error information, +and ask users not to use the value on error. +This is acceptable if the value can be constructed +cheaply with a reasonable default. + +##### Performance of different value-or-error reporting methods + +[P1886R0](https://wg21.link/P1886R0) +(Ben Craig, "Error speed benchmarking") +surveys different ways in Standard C++ +to report errors from a function +that returns one or more values, +and compares their (host-only) performance +with different compilers. + +##### Use aggregate initialization when returning a struct or tuple + +Use aggregate initialization when returning a struct or tuple. +This avoids duplication of the return type name. + +```c++ +struct foo_result { + float value = 0.0f; + float error = 0.0f; + bool success = false; +}; + +foo_result foo(std::span input) +{ + // ... code ... + + // Prefer this. We know what type the function returns. + return {val, err, ok}; // prefer this + + // Naming foo_result again here is unnecessary. + // return foo_result{val, err, ok}; +} +``` + +However, note that this won't work if the function returns `auto`. +The general rule is to avoid code duplication. + +```c++ +auto foo(std::span input) +{ + // ... code ... + + if constexpr (some_condition) { + return foo_result{val, err, ok}; + } + else { + return bar_result{val, err, ok}; + } +} +``` + +##### Prefer using the actual return type to auto, if you know the type + +C++ lets you use `auto` to deduce the type returned from a function. + +* If you know the actual type, prefer using the type instead of `auto`. + +* Use [Constructor Type Argument Deduction](https://en.cppreference.com/w/cpp/language/class_template_argument_deduction) + (CTAD) if you know that a function returns some type + (e.g., `Tensor`), but don't know the type's template arguments. + +* Use `auto` in structured bindings (where you have to use it anyway). This also makes your code agnostic of whether the return type is a `struct`, `tuple`, `pair`, or other tuple-like type. + +* Be careful using `auto` with types that provide expression templates. + +Contrast this with "Almost Always Auto" (AAA) style. +We deliberately choose not to follow AAA style, +for the following reasons. + +* Using the actual type when we know it can help prevent common loss-of-precision errors in mixed-precision computations, an important use case for CUTLASS. + +* CTAD gives us much of the brevity of AAA, with more clarity. + +* Using the actual type instead of `auto` can prevent common dangling errors with expression templates. + +#### Classes and structs + +Type names use `CamelCase`. +That is, words start with capital letters. +The remaining letters in the word are lower case, +and words are joined with no intervening underscores. +The only exception is when implementations are +a drop-in replacement for C++ Standard Library components. + +Follow the +[C++ Core Guidelines](https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-struct) +to decide whether to use `class` or `struct`. + +* Use `class` when the object must maintain an invariant. + Data members related to the invariant should be `private`. + +* Use `struct` when the class has no invariant to maintain, + and data members may vary arbitrarily with respect to each other. + +Prefer nonmember functions and statelessness where possible. +Member functions imply invariants. +More invariants make code maintenance and testing harder. + +#### Class members + +Methods and members are written using `snake_case`. + +Private data and function members have suffix `_`. + +#### Class Member Order Members within classes and structures should be organized as follows: 1. Type and constant definitions + 2. Data members + 3. Constructors + 4. Other methods -This convention follows the [CUB library](https://nvlabs.github.io/cub/) and is also described by -[Howard Hinnant](https://howardhinnant.github.io/classdecl.html). Unsurprisingly, it approximates -the usual ordering of chapters in a typical Systems and Controls textbook. That is, -(1.) identify relevant constants, (2.) define a state-space representation of the dynamical system -under study (i.e. the data members), and (3.) devote subsequent chapters to definining dynamical behavior -of the system (i.e. the methods). +This convention follows the +[CUB library](https://nvlabs.github.io/cub/) +and is also described by +[Howard Hinnant](https://howardhinnant.github.io/classdecl.html). +It also approximates the usual ordering of chapters +in a typical Systems and Controls textbook. +That is, it + +1. identifies relevant constants, + +2. defines a state-space representation + of the dynamical system under study + (the class's data members), and then + +3. devotes the remaining "chapters" to defining + the system's dynamical behavior + (the class's methods). + +Here is an example class. -_Example_: ```c++ class A { public: - // Type definitions + // type definitions protected: - // protected Type definitions + // protected type definitions private: - // private Type definitions + // private type definitions public: - // Data members + // data members protected: // protected data members + // STRONGLY TO BE AVOIDED; + // please see C++ Core Guidelines private: // private data members public: - // Methods + // methods protected: // protected methods private: // private methods - }; - ``` -### File Names - -Files should be named using `snake_case` with extension `.h` for header files, `.cu` for CUDA sources, -and `.cpp` for C++ host-only source files. +#### Use scoped enums -### Use scoped enums - -Use scoped enums added in C++11 for enumerated types. Use capital letters for the enumerated type name +Use scoped enums (a C++11 feature) for enumerated types. +Use capital letters for the enumerated type name and prefix `k` for enumerators like other constants. ```c++ @@ -232,63 +679,129 @@ enum class MatrixOperation { }; ``` -### Namespaces +#### Namespaces -Namespaces are all lower case. The top-level namespace is `cutlass::`. The second nested namespace refers -top the general category of operation performed by its members, and the third nested namespace refers to -the CUDA execution model scope (if applicable). +Namespaces are all lower case. +The top-level namespace is `cutlass::`. +The second nested namespace refers to +the general category of operation +performed by its members: e.g., `gemm::`. +The third nested namespace refers to +the operations' position in the conceptual hierarchy: +e.g., `device::`, `kernel::`, or `collective::`. -The bodies of namespace definitions should not be intented, and comments on the closing brace are welcome. +The bodies of namespace definitions should not be indented. +Comments on the closing brace to indicate +the namespace being closed are welcome. ```c++ namespace cutlass { namespace gemm { -namespace warp { - -struct MmaTensorCore { +namespace kernel { +struct AnotherGemmKernel { + // ... contents ... }; -} // namespace warp +} // namespace kernel } // namespace gemm } // namespace cutlass ``` -### Macros +#### File Names + +New files should be named using `snake_case` +with extension `.hpp` for header files, +`.cu` for CUDA sources, +and `.cpp` for C++ host-only source files. + +Header files with extension `.h` +are CUTLASS 2.x legacy headers. -Avoid defining macros except where preprocessing is obligatory. In particular, -avoid using macros for constants. +#### Macros -Several existing macros defined in `cutlass/cutlass.h` are useful for working around compiler-dependent -behavior. +Only use macros when the preprocessor +is the only way to accomplish the task. +Do not use macros for literal constants. +Instead, if inside the body of a function, +use `constexpr` values, +and if at namespace scope, use +[`inline constexpr` variables](https://en.cppreference.com/w/cpp/language/inline) +(a C++17 feature). + +"Namespace" macros by starting them with the module name, e.g., `CUTLASS_`. +Macros and ONLY MACROS use all capital letters with underscores between words. +For example: + +```c++ +#define CUTLASS_MACROS_USE_ALL_CAPS inline __host__ __device__ +``` -Annotations for device code: -* `CUTLASS_HOST_DEVICE` for functions running on the host and the device -* `CUTLASS_DEVICE` for functions running on the device only +Header files such as +[cutlass/cutlass.h](../../include/cutlass/cutlass.h) +and +[cute/config.hpp](../../include/cutlass/cutlass.h) +offer macros for expressing compiler-dependent behavior. +These include -Loop unrolling: -* `CUTLASS_PRAGMA_UNROLL` for full unrolling of loops with constant trip counts -* `CUTLASS_PRAGMA_NO_UNROLL` to prevent unrolling +* replacements for `__device__` and/or `__host__` + annotations: -### #pragma once + * `CUTLASS_HOST_DEVICE` or `CUTE_HOST_DEVICE` + for functions that run on the host and the device, + + * `CUTLASS_DEVICE` or `CUTE_DEVICE` + for functions that run on the device only, and + + * `CUTE_HOST` + for functions that run on the host only; and + +* annotations to loop unrolling: + + * `CUTLASS_PRAGMA_UNROLL` or `CUTE_UNROLL` + for full unrolling of loops with constant trip counts, and + + * `CUTLASS_PRAGMA_NO_UNROLL` or `CUTE_NO_UNROLL` to prevent unrolling. + +#### Guard all headers with `#pragma once` Use `#pragma once` to guard all headers. -```c++ -/*! +### CUDA C++ style + +#### CUDA Built-in Variables + +Avoid direct access to CUDA built-in variables `threadIdx`, `blockIdx`, `blockDim`, and `gridDim` within +CUTLASS components except in special circumstances. -*/ +Using built-in global variables directly within resuable components necessitates that all components +use them consistently which may not be possible if CUTLASS components are used in other contexts. -#pragma once +Instead, components should accept a linear ID identifying threads, warps, and threadblocks from calling +code. The top-level kernel may then decide how to map threads, warps, and blocks to the problem it is +solving. -... -``` +#### Use CUTLASS's and CuTe's fundamental types and operations + +Use the +[fundamental types and operations](fundamental_types.md) +defined in CUTLASS consistently. +This contributes to a framework of interoperable, consistent components. +It reduces code duplication, which reduces build and test times. +It also saves developer effort. + +CUTLASS's fundamental types and operations include -### Source Line Length +* [Numeric types](fundamental_types.md#numeric-types) to represent numeric data in host and device code, and -Avoid lines longer than 100 characters. These typically wrap unfavorably when viewed in -Github's pretty printer. +* [functional.h](fundamental_types.md#functional) to perform numeric operations in generic code. +CUTLASS 3.0 uses CuTe components to represent data layouts and multidimensional arrays. +Please refer to the [CuTe Tutorial](./cute/00_quickstart.md) for details. +CuTe has replaced CUTLASS 2.x components such as +[Containers](fundamental_types.md#containers), +[Layouts](layout.md), and +[`TensorRef` and `TensorView`](layout.md#tensorref). # Copyright diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index ff13abf9c7..f0d4d8a311 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -7,9 +7,9 @@ ## Prerequisites CUTLASS requires: -- NVIDIA CUDA Toolkit (9.2 or later required, [11.1](https://developer.nvidia.com/cuda-toolkit) recommended) -- CMake 3.12+ -- host compiler supporting C++11 or greater (g++ 7.3.0 or Microsoft Visual Studio 2015 recommended) +- NVIDIA CUDA Toolkit (11.4 or later required, [12.0](https://developer.nvidia.com/cuda-toolkit) recommended) +- CMake 3.18+ +- host compiler supporting C++17 or greater (minimum g++ 7.5.0) - Python 3.6+ CUTLASS may be optionally compiled and linked with @@ -24,13 +24,13 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc $ mkdir build && cd build -$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture ``` If your goal is strictly to build only the CUTLASS Profiler and to minimize compilation time, we suggest executing the following CMake command in an empty `build/` directory. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON ``` This reduces overall compilation time by excluding unit tests and enabling the unit build. @@ -39,13 +39,13 @@ You may reduce build times by compiling only certain operations by setting the ` executed from an empty `build/` directory. This only compiles 2-D convolution kernels. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_OPERATIONS=conv2d +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_OPERATIONS=conv2d ``` -You may also filter kernels by name by supplying a filter string with flag `CUTLASS_LIBRARY_KERNELS`. +You may also filter kernels by name by supplying a filter string with flag `CUTLASS_LIBRARY_KERNELS`. For example the below command selects only CUTLASS-3 kernels. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=s16816gemm,s16816fprop*128x128 +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=cutlass3x* ``` See more examples on selectively compiling CUTLASS GEMM and convolution kernels [here](quickstart.md#example-cmake-commands). @@ -180,6 +180,10 @@ To minimize compilation time, specific GPU architectures can be enabled via the selected by [CUDA Compute Capability.](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities) **NVIDIA Ampere Architecture.** +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture +``` + ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture ``` @@ -204,32 +208,10 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS="60;61" # compiles for NVIDIA Pascal GP $ cmake .. -DCUTLASS_NVCC_ARCHS="50;53" # compiles for NVIDIA Maxwell GPU architecture ``` -## Clang - -For experimental purposes, CUTLASS has been verified to compile with the following versions of Clang and CUDA. - -* [clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the -[CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-10.0-download-archive). -* [clang release/13.x](https://github.com/llvm/llvm-project/tree/release/13.x) using [CUDA 11.4](https://developer.nvidia.com/cuda-toolkit-archive) - -At this time, compiling with clang enables the CUTLASS SIMT GEMM kernels (sgemm, dgemm, hgemm, igemm) -but does not enable TensorCores. - -```bash -$ mkdir build && cd build - -$ cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ .. -# Add -DCMAKE_CXX_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 -DCMAKE_CUDA_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 if compiler -# checks fail during CMake configuration. - -$ make test_unit -j -``` - - ## Using CUTLASS within other applications Applications should list [`/include`](/include) within their include paths. They must be -compiled as C++11 or greater. +compiled as C++17 or greater. **Example:** print the contents of a variable storing half-precision data. ```c++ @@ -345,6 +327,136 @@ Note, the above could be simplified as follows using helper methods defined in ` }); ``` +## Launching a GEMM kernel using CUTLASS 3.0 or newer + +**Example:** launch a mixed-precision GEMM targeting Hopper Tensor Cores. + +```c++ +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +int main(int argc, char const **args) { + + // A matrix configuration + using ElementA = cutlass::half_t; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = cutlass::half_t; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = cutlass::half_t; // Element type for C and D matrix operands + using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TilesShape = Shape<_128,_128,_64>; // Threadblock-level tile size + using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TilesShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + Gemm gemm_op; + cutlass::Status status; + + // + // Define the problem size + // + + int M = 512; + int N = 256; + int K = 128; + + float alpha = 1.25f; + float beta = -1.25f; + + // + // Allocate device memory + // + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})); + stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})); + stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})); + stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})); + + block_A.reset(M * K); + block_B.reset(K * N); + block_C.reset(M * N); + block_D.reset(M * N); + + // + // Launch GEMM on the device + // + + status = gemm_op({ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + block_A.get(), + stride_A, + block_B.get(), + stride_B, + {block_C.get(), stride_C, block_D.get(), stride_D, {alpha, beta}} + }); + + if (status != cutlass::Status::kSuccess) { + return -1; + } + + return 0; +} +``` + # CUTLASS Library The [CUTLASS Library](/tools/library) defines an API for managing and executing collections of compiled diff --git a/media/docs/terminology.md b/media/docs/terminology.md index f2d8b6838c..e0f04790a3 100644 --- a/media/docs/terminology.md +++ b/media/docs/terminology.md @@ -4,10 +4,10 @@ # CUTLASS Terminology -`AlignedBuffer`: statically sized array type; union-safe, no construction guarantee for elements +**cute::Layout**: A `cute::Layout` vocabulary type composed of the hierarchical `cute::Shape` and `cute::Stride` +tuples that is used throughout CUTLASS 3.0 to represent and manipulate thread and data layouts. More details are included in the [CuTe specific tensor type documentation](/media/docs/cute/03_tensor.md). -`Array`: container for holding numeric types - handles bit packing for small numeric types (e.g. int4_t, uint4_t, bin1_t) - `sizeof(Array)` - gives expected value in units of bytes with minimum storage of `1 B`: (sizeof_bits::value * N) / 8 +**cute::Tensor**: A pointer backed by a `cute::Layout` used to represent a tensor. More details are included in the [CuTe specific tensor type documentation](/media/docs/cute/03_tensor.md). **Capacity**: (scalar) physical number of elements in memory required to store a multidimensional object; expressed as the type's LongIndex type - example: the capacity of a column-major matrix is `lda * N` @@ -28,8 +28,6 @@ **Numeric Type**: a CUTLASS data type used to represent real-valued quantities; is trivially copyable. -**Operator**: an object performing a computation on matrix or tensor objects. May be further refined by scope within the execution model hierarchy. - **Pitch Linear**: linear memory allocation obtained from a user-defined 2-D size, which specifies the contiguous and strided dimensions of a tile. @@ -61,17 +59,27 @@ contiguous and strided dimensions of a tile. **Tile**: partitions of a tensor that have constant extents and layout known at compile time -**Tile Iterator**: abstraction for accessing and traversing a sequence of tiles in a tensor; CUTLASS specifies - [formal concepts for tile iterators](tile_iterator_concept.md) - -**Thread Map**: abstraction for defining how threads are mapped to a given tile. - **Trait**: characteristics of a fully-specialized type, typically used in metaprogramming reflection **View**: an object containing references to a data structure that it does not own; typically, construction of views is lightweight **Warp**: a collection of hardware threads executing in lock-step; warp-level operations typically rely on cooperation among the threads within the warp +`AlignedBuffer`: statically sized array type; union-safe, no construction guarantee for elements + +`Array`: container for holding numeric types - handles bit packing for small numeric types (e.g. int4_t, uint4_t, bin1_t) + `sizeof(Array)` - gives expected value in units of bytes with minimum storage of `1 B`: (sizeof_bits::value * N) / 8 + +**Operator**: an object performing a computation on matrix or tensor objects. May be further refined by scope within the execution model hierarchy. Deprecated starting CUTLASS 3.0, +replaced by [MMA and Copy atoms from CuTe](/media/docs/cute/0t_mma_atom.md). + +**Tile Iterator**: abstraction for accessing and traversing a sequence of tiles in a tensor; CUTLASS specifies + [formal concepts for tile iterators](tile_iterator_concept.md). Deprecated starting CUTLASS 3.0. + Replaced by `cute::Layout` in equivalent usage scenarios to represent data tensors. + +**Thread Map**: abstraction for defining how threads are mapped to a given tile. Deprecated starting CUTLASS 3.0. + Replaced by `cute::Layout` in equivalent usage scenarios to represent thread tensors. + # Copyright Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/media/docs/tile_iterator_concept.md b/media/docs/tile_iterator_concept.md index 3c20797fd9..efff36131d 100644 --- a/media/docs/tile_iterator_concept.md +++ b/media/docs/tile_iterator_concept.md @@ -4,9 +4,15 @@ # Tile Iterator Concepts +Note: CUTLASS 3.0 deprecates all tile access iterators in favour of CuTe's single +vocabulary type `cute::Tensor`, which is parameterized on `cute::Layout`. +`cute::Tensor`s can therefore be manipulated with the same layout algebra as all CuTe layouts. +This removes the need for bespoke types that encapsulate iterator properties. +The following text thus only applies to legacy CUTLASS 2.x API and related types. + CUTLASS 2.x implements generic algorithms on tiles of matrix or tensors of constant size. These may be considered as partitions of tensors of infinite size, with a range of partitions accessible -by _tile iterators_. +by _tile iterators_. Various data structures may make operations such as random access to tiles inexpensive, while data structures may not offer random access at all. For example, iterating over a linked @@ -14,7 +20,9 @@ list of matrices requires sequential traversal. Algorithms implemented in terms should require only the minimum set of operators be defined for tile iterators. This document describes a set of C++ concepts which may be used to define tile iterators used -by CUTLASS algorithms. Each concept specifies members and type definitions that a tile iterator +by CUTLASS algorithms. ("Concept" here does not refer to a C++20 concept that uses the `concept` keyword. +Rather, it refers to a set of requirements on a type.) +Each concept specifies members and type definitions that a tile iterator must implement. Frequently, a tile iterator implements several concepts, and its members are the union of the members from each individual concept. These definitions were inspired by [Boost "New style" iterator concepts](https://www.boost.org/doc/libs/1_40_0/libs/iterator/doc/new-iter-concepts.html). @@ -23,7 +31,6 @@ The set of all possible combinations of these concepts is quite large, however m templates can be described by one of several combinations. The section Frequently Used Tile Iterator Concepts describes several common interfaces used throughout CUTLASS. - ## Definitions **_Base Tile Iterator Concept_.** All tile iterators must describe an _Element_ type as well as a _Shape_. diff --git a/media/docs/utilities.md b/media/docs/utilities.md index 66e71cad7b..c464f2007d 100644 --- a/media/docs/utilities.md +++ b/media/docs/utilities.md @@ -2,6 +2,13 @@ [README](/README.md#documentation) > **CUTLASS Utilities** +Note: This document discusses utilities commonly used with code that targets CUTLASS 2.x. +Although CUTLASS 3.0's primary entry point APIs do not transact in these `cutlass::*` tensor types anymore, +users can still find them convenient for managing allocations with trivial affine layouts. +For more advanced host side tensor management, [`cute::Tensor`](/media/docs/cute/03_tensor.md)s +can be used on either host or device for any memory space and full expressive power of +[`cute::Layout`](/media/docs/cute/01_layout.md)s. + # CUTLASS Utilities CUTLASS utilities are additional template classes that facilitate recurring tasks. These are diff --git a/media/images/cute/HMMA.8x8x4.NT.png b/media/images/cute/HMMA.8x8x4.NT.png new file mode 100644 index 0000000000..adedbac03c Binary files /dev/null and b/media/images/cute/HMMA.8x8x4.NT.png differ diff --git a/media/images/cute/HMMA.8x8x4.quadpair.AB.png b/media/images/cute/HMMA.8x8x4.quadpair.AB.png new file mode 100644 index 0000000000..2b04c7328a Binary files /dev/null and b/media/images/cute/HMMA.8x8x4.quadpair.AB.png differ diff --git a/media/images/cute/HMMA.8x8x4.quadpair.C.png b/media/images/cute/HMMA.8x8x4.quadpair.C.png new file mode 100644 index 0000000000..2e255e420d Binary files /dev/null and b/media/images/cute/HMMA.8x8x4.quadpair.C.png differ diff --git a/media/images/cute/gmma_coremat_cd_fp16.png b/media/images/cute/gmma_coremat_cd_fp16.png new file mode 100644 index 0000000000..f84e0d249e Binary files /dev/null and b/media/images/cute/gmma_coremat_cd_fp16.png differ diff --git a/media/images/cute/gmma_wg_n_slice.png b/media/images/cute/gmma_wg_n_slice.png new file mode 100644 index 0000000000..6fa03c0e31 Binary files /dev/null and b/media/images/cute/gmma_wg_n_slice.png differ diff --git a/media/images/cute/logical_divide-and-zipped_divide-2.png b/media/images/cute/logical_divide-and-zipped_divide-2.png new file mode 100755 index 0000000000..c1c29a4a1d Binary files /dev/null and b/media/images/cute/logical_divide-and-zipped_divide-2.png differ diff --git a/media/images/cute/logical_divide-and-zipped_divide.png b/media/images/cute/logical_divide-and-zipped_divide.png new file mode 100755 index 0000000000..471649f539 Binary files /dev/null and b/media/images/cute/logical_divide-and-zipped_divide.png differ diff --git a/media/images/cutlass-3.0-gemm-peak-performance.png b/media/images/cutlass-3.0-gemm-peak-performance.png new file mode 100644 index 0000000000..4e92a56be6 Binary files /dev/null and b/media/images/cutlass-3.0-gemm-peak-performance.png differ diff --git a/media/images/cutlass-reduction-in-named-iterators.png b/media/images/cutlass-reduction-in-named-iterators.png new file mode 100644 index 0000000000..446fa9bf82 Binary files /dev/null and b/media/images/cutlass-reduction-in-named-iterators.png differ diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 445edc64ab..c4e0634cb1 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -28,6 +28,8 @@ include(CTest) +set(CUTLASS_UNIT_TEST_COMMON_DIR ${CMAKE_CURRENT_LIST_DIR}/common) + cutlass_add_library( cutlass_test_unit_infra OBJECT @@ -42,6 +44,7 @@ target_link_libraries( $<$:nvidia::cublas> gtest cudart + cuda_driver ) cutlass_add_library( @@ -69,6 +72,12 @@ function(cutlass_test_unit_add_executable NAME) target_compile_definitions(${NAME} PUBLIC CUTLASS_TARGET_NAME="${NAME}") + target_include_directories( + ${NAME} + PRIVATE + ${CUTLASS_UNIT_TEST_COMMON_DIR} + ) + target_link_libraries( ${NAME} PRIVATE @@ -76,6 +85,10 @@ function(cutlass_test_unit_add_executable NAME) cutlass_test_unit_infra_lib ) + if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND) + target_link_libraries(${NAME} PRIVATE OpenMP::OpenMP_CXX) + endif() + string(REGEX REPLACE cutlass_ "" NAME_STEM ${NAME}) set(RESULT_CACHE_FILE "${CUTLASS_TEST_UNIT_RESULTS_CACHE_DIR}/cached_results_${NAME}.txt") @@ -99,6 +112,7 @@ add_custom_target(test_unit) set(SUBDIRS core + cute gemm conv layout @@ -106,6 +120,7 @@ set(SUBDIRS epilogue reduction util + pipeline ) if(TARGET nvidia::nvrtc AND TARGET nvidia::cuda_driver) diff --git a/test/unit/common/cutlass_unit_test.h b/test/unit/common/cutlass_unit_test.h index b631274314..8843e40b70 100644 --- a/test/unit/common/cutlass_unit_test.h +++ b/test/unit/common/cutlass_unit_test.h @@ -39,6 +39,17 @@ #include #include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice(); + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Sets flags for Unit test @@ -52,7 +63,6 @@ int CutlassUnitTestProblemCount(); ///////////////////////////////////////////////////////////////////////////////////////////////// - // active test macro #define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ @@ -78,3 +88,15 @@ int CutlassUnitTestProblemCount(); #if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) #define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false #endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUDA_12_0_SM90_FEATURES_SUPPORTED true +#else + #define CUDA_12_0_SM90_FEATURES_SUPPORTED false +#endif + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 55986ddedd..553915f349 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -35,9 +35,49 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice() { + + cudaError_t err; + + int cudaDeviceId; + err = cudaGetDevice(&cudaDeviceId); + if (cudaSuccess != err) { + std::cerr << "*** Error: Could not detect active GPU device ID" + << " [" << cudaGetErrorString(err) << "]" << std::endl; + exit(1); + } + + cudaDeviceProp deviceProperties; + err = cudaGetDeviceProperties(&deviceProperties, cudaDeviceId); + + return deviceProperties; +} + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperties) { + + int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor; + if (deviceMajorMinor) { + int32_t clock_MHz = deviceProperties.clockRate / 1000; + out << "GPU(compute_" + << deviceMajorMinor << ", " + << deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)"; + } + else { + out << "No CUDA device."; + } + + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Sets flags for Unit test void FilterArchitecture() { // Default flags can be overwritten by --gtest_filter from commandline + + int const kMaxDevice = 999; + cudaError_t err; int cudaDeviceId; @@ -57,7 +97,6 @@ void FilterArchitecture() { } int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor; - int const kMaxDevice = 999; // Defines text filters for each GEMM kernel based on minimum supported compute capability struct { @@ -78,7 +117,7 @@ void FilterArchitecture() { { "SM70*", 70, 75}, { "SM75*", 75, kMaxDevice}, { "SM80*", 80, kMaxDevice}, - { "SM90*", 90, kMaxDevice}, + { "SM90*", 90, 90 }, { 0, 0, false } }; diff --git a/test/unit/conv/device/conv2d_testbed_interleaved.h b/test/unit/conv/device/conv2d_testbed_interleaved.h index 352c19b8aa..79f00d156b 100644 --- a/test/unit/conv/device/conv2d_testbed_interleaved.h +++ b/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -186,6 +186,34 @@ class InterleavedTestbedConv2d { tensor_D_reference.sync_device(); } + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + /// Executes one test bool run( cutlass::conv::Conv2dProblemSize const &problem_size, @@ -193,6 +221,14 @@ class InterleavedTestbedConv2d { ElementCompute alpha = ElementCompute(1), ElementCompute beta = ElementCompute(0)) { + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + #if 0 //display conv2d problem size for debugging std::cout << problem_size << std::endl << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl diff --git a/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu index c30ca90c89..8efc73e1d9 100644 --- a/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu +++ b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu @@ -328,6 +328,7 @@ TEST( DepthwiseFpropProblemSizes_filter5x5())); } +#if 0 //////////////////////////////////////////////////////////////////////////////// TEST( SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, @@ -424,3 +425,5 @@ TEST( EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( DepthwiseFpropProblemSizes_filter5x37())); } +#endif + diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index cd43bf63a8..8d7a29686f 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -152,9 +152,7 @@ TEST(NumericConversion, f32_to_fe5m2_rn_array) { int const kN = 27; using Source = float; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); - } TEST(NumericConversion, f16_to_fe4m3_rn) { @@ -250,16 +248,19 @@ TEST(NumericConversion, fe4m3_to_f32_rn) { test::core::kernel::run_test(); } -TEST(NumericConversion, fe4m3_to_f32_array) { - int const kN = 27; - using Source = cutlass::float_e4m3_t; - using Destination = float; +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(NumericConversion, f32x8_to_s8x8_rn) { + + int const kN = 8; + using Source = float; + using Destination = int8_t; test::core::kernel::run_test(); } -TEST(NumericConversion, fe5m2_to_f32_rn) { - int const kN = 1; - using Source = cutlass::float_e5m2_t; +TEST(NumericConversion, fe4m3_to_f32_array) { + int const kN = 27; + using Source = cutlass::float_e4m3_t; using Destination = float; test::core::kernel::run_test(); } @@ -328,35 +329,3 @@ TEST(NumericConversion, fe5m2_to_bf16_array) { } ///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(NumericConversion, f32x8_to_s8x8_rn) { - - int const kN = 8; - using Source = float; - using Destination = int8_t; - - dim3 grid(1, 1); - dim3 block(1, 1); - - cutlass::HostTensor destination({1, kN}); - cutlass::HostTensor source({1, kN}); - - for (int i = 0; i < kN; ++i) { - source.host_data()[i] = float(i); - } - - source.sync_device(); - - test::core::kernel::convert<<< grid, block >>>( - reinterpret_cast *>(destination.device_data()), - reinterpret_cast const *>(source.device_data()) - ); - - destination.sync_host(); - - for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]); - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/cute/CMakeLists.txt b/test/unit/cute/CMakeLists.txt new file mode 100644 index 0000000000..43a7bd00b1 --- /dev/null +++ b/test/unit/cute/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +add_subdirectory(core) +add_subdirectory(ampere) +add_subdirectory(hopper) +add_subdirectory(layout) + +add_custom_target( + cutlass_test_unit_cute + DEPENDS + cutlass_test_unit_cute_layout + cutlass_test_unit_cute_core + cutlass_test_unit_cute_ampere + cutlass_test_unit_cute_hopper + ) + +add_custom_target( + test_unit_cute + DEPENDS + test_unit_cute_layout + test_unit_cute_core + test_unit_cute_ampere + test_unit_cute_hopper + ) diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt new file mode 100644 index 0000000000..91a5f5f0fa --- /dev/null +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_ampere + cp_async.cu + ldsm.cu +) diff --git a/test/unit/cute/ampere/cp_async.cu b/test/unit/cute/ampere/cp_async.cu new file mode 100644 index 0000000000..7a80a518cd --- /dev/null +++ b/test/unit/cute/ampere/cp_async.cu @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace cute; + +__global__ void +test(double const* g_in, double* g_out) +{ + extern __shared__ double smem[]; + + smem[threadIdx.x] = g_in[threadIdx.x]; + + __syncthreads(); + + g_out[threadIdx.x] = 2 * smem[threadIdx.x]; +} + +__global__ void +test2(double const* g_in, double* g_out) +{ + using namespace cute; + + extern __shared__ double smem[]; + + auto s_tensor = make_tensor(make_smem_ptr(smem + threadIdx.x), Int<1>{}); + auto g_tensor = make_tensor(make_gmem_ptr(g_in + threadIdx.x), Int<1>{}); + + copy(g_tensor, s_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + g_out[threadIdx.x] = 2 * smem[threadIdx.x]; +} + +TEST(SM80_CuTe_Ampere, CpAsync) +{ + constexpr int count = 32; + thrust::host_vector h_in(count); + for (int i = 0; i < count; ++i) { + h_in[i] = double(i); + } + + thrust::device_vector d_in(h_in); + + thrust::device_vector d_out(count, -1); + test<<<1, count, sizeof(double) * count>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_result = d_out; + + thrust::device_vector d_out_cp_async(count, -2); + test2<<<1, count, sizeof(double) * count>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out_cp_async.data())); + thrust::host_vector h_result_cp_async = d_out_cp_async; + + for (int i = 0; i < count; ++i) { + EXPECT_EQ(h_result[i], h_result_cp_async[i]); + } +} diff --git a/test/unit/cute/ampere/ldsm.cu b/test/unit/cute/ampere/ldsm.cu new file mode 100644 index 0000000000..15ec44b33a --- /dev/null +++ b/test/unit/cute/ampere/ldsm.cu @@ -0,0 +1,431 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include + +#include + + +using namespace cute; + +template +__global__ void +ldsm_test_device(uint16_t* g_in, uint16_t* g_out) +{ + constexpr int count = sizeof(T) / 4; + int tid = threadIdx.x; + int stride = blockDim.x; + + // load input gmem -> smem + __shared__ uint32_t smem[32 * count]; + for (int i = 0; i < count; ++i) { + smem[tid + (stride * i)] = reinterpret_cast(g_in)[tid + (stride * i)]; + } + + __syncthreads(); + + uint32_t reg[count]; + for (int i = 0; i < count; ++i) { + reg[i] = 0; + } + + // load smem -> rmem using LDSM + uint128_t* smem_ptr = reinterpret_cast(smem) + tid; + T* rmem_ptr = reinterpret_cast(reg); + cute::copy_ldsm(smem_ptr, rmem_ptr); + + // store output rmem -> gmem + for (int i = 0; i < count; ++i) { + reinterpret_cast(g_out)[tid + (stride * i)] = reg[i]; + } +} + +template +__global__ void +ldsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, + TiledCopy tiled_copy, SmemLayout smem_layout) +{ + using namespace cute; + + __shared__ uint16_t smem[size(smem_layout)]; + + auto t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); + auto t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); + auto t_smem = make_tensor(make_smem_ptr(smem), smem_layout); + + int tid = threadIdx.x; + + // Load input gmem -> smem + for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { + t_smem(i) = t_g_in(i); + } + + __syncthreads(); + + auto thr_copy = tiled_copy.get_thread_slice(tid); + + auto tXsX = thr_copy.partition_S(t_smem); // (V,M,N) + auto tXgX = thr_copy.partition_D(t_g_out); // (V,M,N) + + auto tXrX = make_tensor(shape(tXgX)); // (V,M,N) + clear(tXrX); // Just to make sure + +/* + if (thread0()) { + print("tXsX: " ); print(tXsX.layout()); print("\n"); + print("tXgX: " ); print(tXgX.layout()); print("\n"); + print("tXrX: " ); print(tXrX.layout()); print("\n"); + } +*/ + + // Copy smem -> rmem via tiled_copy (LDSM, LDS) + copy(tiled_copy, tXsX, tXrX); + + // Output rmem -> gmem + copy(tXrX, tXgX); +} + + +TEST(SM80_CuTe_Ampere, Ldsm) +{ + constexpr int count = 1024; + + thrust::host_vector h_in(count); + for (int i = 0; i < count; ++i) { + h_in[i] = uint16_t(i); + } + thrust::device_vector d_in = h_in; + + // + // LDSM 1x (32b) + // + + { + thrust::device_vector d_out(count); + ldsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 32; ++i) { + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("LDSM 1x ldsm_test_device SUCCESS\n"); + } + + // + // LDSM 2x (64b) + // + + { + thrust::device_vector d_out(count); + ldsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 64; ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("LDSM 2x ldsm_test_device SUCCESS\n"); + } + + // + // LDSM 4x (128b) + // + + { + thrust::device_vector d_out(count); + ldsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 128; ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("LDSM 4x ldsm_test_device SUCCESS\n"); + } + + // + // CuTe LDSM + // + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i] , h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved LDS.U16 SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_LDSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 LDS.U16 SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_LDSM_T SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_LDSM_T SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_LDSM_T SUCCESS\n"); + } + + CUTLASS_TRACE_HOST("PASS"); +} diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt new file mode 100644 index 0000000000..e8e3555aed --- /dev/null +++ b/test/unit/cute/core/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_core + + bitfield.cpp + coalesce.cpp + compare.cpp + complement.cpp + composition.cpp + inverse_left.cpp + inverse_right.cpp + logical_divide.cpp + logical_product.cpp + mixedbits.cpp + transform.cpp + tuple.cpp +) diff --git a/test/unit/cute/core/bitfield.cpp b/test/unit/cute/core/bitfield.cpp new file mode 100644 index 0000000000..94b139e385 --- /dev/null +++ b/test/unit/cute/core/bitfield.cpp @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace cute; + +TEST(CuTe_core, Bitfield) +{ + for_each(make_int_range<1,65>{}, [&](auto NumBits) { + for_each(make_int_range<0,129>{}, [&](auto BitStart) { + + using BF = bit_field; + +#if 0 + printf("bit_field<%d,%d>:\n", decltype(BitStart)::value, decltype(NumBits)::value); + printf(" value_type_bits : %d\n", BF::value_type_bits); + printf(" storage_type_bits: %d\n", BF::storage_type_bits); + printf(" N : %d\n", BF::N); + printf(" idx : %d\n", BF::idx); + printf(" bit_lo : %d\n", BF::bit_lo); + printf(" bit_hi : %d\n", BF::bit_hi); + printf(" mask : 0x%lx\n", uint64_t(BF::mask)); + printf(" mask_lo : 0x%lx\n", uint64_t(BF::mask_lo)); + printf(" mask_hi : 0x%lx\n", uint64_t(BF::mask_hi)); +#endif + + // Test + uint64_t v = decltype(NumBits)::value == 64 ? uint64_t(-1) : ((uint64_t(1) << NumBits) - 1); + + BF bf{}; + bf = v; + EXPECT_EQ(v, uint64_t(bf)); + }); + }); + + for_each(make_int_range<0,129>{}, [&](auto BitStart) { + + using BF = bit_field; + + BF bf{}; + bf = 3.14f; + EXPECT_EQ(3.14f, float(bf)); + }); + +} diff --git a/test/unit/cute/core/coalesce.cpp b/test/unit/cute/core/coalesce.cpp new file mode 100644 index 0000000000..2fdeb7c35a --- /dev/null +++ b/test/unit/cute/core/coalesce.cpp @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + + +#include + +using namespace cute; + +template +void +test_coalesce(Layout const& layout) +{ + auto coalesce_layout = coalesce(layout); + + CUTLASS_TRACE_HOST(shape (layout) << " => " << shape (coalesce_layout)); + CUTLASS_TRACE_HOST(stride(layout) << " " << stride(coalesce_layout)); + + CUTE_STATIC_ASSERT_V(depth(coalesce_layout) <= Int<1>{}); + + ASSERT_EQ(size(coalesce_layout), size(layout)); + + for (int i = 0; i < size(layout); ++i) { + EXPECT_EQ(coalesce_layout(i), layout(i)); + } +} + +TEST(CuTe_core, Coalesce) +{ + { + auto layout = make_layout(Int<1>{}, Int<0>{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(Int<1>{}, Int<1>{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{}, Int<6>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape (Int<2>{}, Int<1>{}, Int<6>{}), + make_stride(Int<1>{}, Int<6>{}, Int<2>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape (Int<2>{}, Int<1>{}, Int<6>{}), + make_stride(Int<1>{}, 7, Int<2>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape (Int<2>{}, Int<1>{}, Int<6>{}), + make_stride(Int<4>{}, 7, Int<8>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(2, Int<4>{}, Int<6>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, 4, Int<6>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{}, 6)); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{}, Int<6>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(2, Int<4>{}, Int<6>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, 4, Int<6>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<4>{}, 6), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, Int<1>{}, Int<3>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, 1, Int<3>{}), GenRowMajor{}); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, 1, Int<3>{}), make_stride(Int<2>{}, 4, Int<4>{})); + + test_coalesce(layout); + } + + { + auto layout = make_layout(make_shape(Int<2>{}, 1, Int<3>{}), make_stride(Int<2>{}, Int<0>{}, Int<4>{})); + + test_coalesce(layout); + } + + { + auto layout = Layout,Shape<_2, _2>>, + Stride,Stride<_8,_32>>>{}; + + test_coalesce(layout); + } +} diff --git a/test/unit/cute/core/compare.cpp b/test/unit/cute/core/compare.cpp new file mode 100644 index 0000000000..5e0c3eecc6 --- /dev/null +++ b/test/unit/cute/core/compare.cpp @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +TEST(CuTe_core, Compare_simple_2d_GenColMajor) +{ + using namespace cute; + + // Simple 2D layout + auto layout = make_layout(make_shape(Int<3>{}, Int<5>{}), GenColMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + + CUTLASS_TRACE_HOST(i << ": " << coord_i); + + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + CUTLASS_TRACE_HOST(" " << j << ": " << coord_j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), colex_less(coord_i,coord_j)); + } + } +} + + +TEST(CuTe_core, Compare_simple_2d_GenRowMajor) +{ + using namespace cute; + + auto layout = make_layout(make_shape(Int<3>{}, Int<5>{}), GenRowMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + CUTLASS_TRACE_HOST(i << ": " << coord_i); + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), lex_less(coord_i,coord_j)); + } + } +} + + +TEST(CuTe_core, Compare_simple_3d_GenColMajor) +{ + using namespace cute; + + auto layout = make_layout(make_shape(Int<2>{}, Int<3>{}, Int<5>{}), GenColMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + CUTLASS_TRACE_HOST(i << ": " << coord_i); + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), colex_less(coord_i,coord_j)); + } + } +} + + +TEST(CuTe_core, Compare_simple_3d_GenRowMajor) +{ + using namespace cute; + + auto layout = make_layout(make_shape(Int<2>{}, Int<3>{}, Int<5>{}), GenRowMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + CUTLASS_TRACE_HOST(i << ": " << coord_i); + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), lex_less(coord_i,coord_j)); + } + } +} + + +TEST(CuTe_core, Compare_hierarchical_3d_GenColMajor) +{ + using namespace cute; + + auto layout = make_layout(Shape,Shape<_5,_2,_2>>{}, GenColMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + CUTLASS_TRACE_HOST(i << ": " << coord_i); + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), colex_less(coord_i,coord_j)); + } + } +} + +TEST(CuTe_core, Compare_hierarchical_3d_GenRowMajor) +{ + using namespace cute; + auto layout = make_layout(Shape,Shape<_5,_2,_2>>{}, GenRowMajor{}); + CUTLASS_TRACE_HOST("Layout: " << layout); + + for (int i = 0; i < size(layout); ++i) { + auto coord_i = layout.get_hier_coord(i); + CUTLASS_TRACE_HOST(i << ": " << coord_i); + EXPECT_TRUE(elem_less(coord_i, shape(layout))); + + for (int j = 0; j < size(layout); ++j) { + auto coord_j = layout.get_hier_coord(j); + EXPECT_TRUE(elem_less(coord_j, shape(layout))); + + EXPECT_EQ((i < j), lex_less(coord_i,coord_j)); + } + } +} diff --git a/test/unit/cute/core/complement.cpp b/test/unit/cute/core/complement.cpp new file mode 100644 index 0000000000..cfad54ff1a --- /dev/null +++ b/test/unit/cute/core/complement.cpp @@ -0,0 +1,273 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include + +template +void +test_complement(Layout const& layout, CoSizeHi const& cosize_hi) +{ + using namespace cute; + + auto result = complement(layout, cosize_hi); + + CUTLASS_TRACE_HOST("complement( " << layout << ", " << cosize_hi << ") => " << result); + + // Post-condition on the domain size of the complement (1) + EXPECT_GE( size(result), cosize_hi / size(filter(layout))); + // Post-condition on the codomain size of the complement (2) + EXPECT_LE(cosize(result), cute::ceil_div(cosize_hi, cosize(layout)) * cosize(layout)); + + // Post-condition on the codomain of the complement + for (int i = 1; i < size(result); ++i) { + EXPECT_LT(result(i-1), result(i)); // Ordered (3) + for (int j = 0; j < size(layout); ++j) { + EXPECT_NE(result(i), layout(j)); // Complemented (4) + } + } + + // Other observations + EXPECT_LE(size(result),cosize(result)); // As a result of the ordered condition (3) + EXPECT_GE(cosize(result), cosize_hi / size(filter(layout))); // As a result of (1) (2) and (5) + if constexpr (is_static::value) { // If we can apply complement again + EXPECT_EQ(size(complement(make_layout(layout,result))), 1); // There's no more codomain left over + } +} + +template +void +test_complement(Layout const& layout) +{ + return test_complement(layout, cosize(layout)); +} + +TEST(CuTe_core, Complement) +{ + using namespace cute; + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("COMPLEMENT"); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = Layout<_1,_0>{}; + + test_complement(layout); + test_complement(layout, Int<2>{}); + } + + { + auto layout = Layout<_1,_1>{}; + + test_complement(layout); + test_complement(layout, Int<2>{}); + } + + { + auto layout = Layout<_1,_2>{}; + + test_complement(layout, Int<1>{}); + test_complement(layout, Int<2>{}); + test_complement(layout, Int<8>{}); + } + + { + auto layout = Layout<_4,_0>{}; + + test_complement(layout, Int<1>{}); + test_complement(layout, Int<2>{}); + test_complement(layout, Int<8>{}); + } + + { + auto layout = Layout<_4,_1>{}; + + test_complement(layout, Int<1>{}); + test_complement(layout, Int<2>{}); + test_complement(layout, Int<8>{}); + } + + { + auto layout = Layout<_4,_2>{}; + + test_complement(layout, Int<1>{}); + test_complement(layout); + test_complement(layout, Int<16>{}); + } + + { + auto layout = Layout<_4,_4>{}; + + test_complement(layout, Int<1>{}); + test_complement(layout); + test_complement(layout, Int<17>{}); + } + + { + auto layout = Layout>{}; + + test_complement(layout); + } + + { + auto layout = Layout>{}; + + test_complement(layout); + } + + { + auto layout = Layout, Stride<_1,_4>>{}; + + test_complement(layout); + } + + { + auto layout = Layout, Stride<_8,_1,_64>>{}; + + test_complement(layout); + } + + { + auto layout = Layout, Stride<_8,_1,_0>>{}; + + test_complement(layout); + test_complement(layout, Int<460>{}); + } + + { + auto layout = make_layout(Shape,Shape<_2, _2>>{}, + Stride,Stride<_8,_32>>{}); + + test_complement(layout); + } + + { + auto layout = make_layout(Shape,Shape<_2, _2>>{}, + Stride,Stride<_8,_4>>{}); + + test_complement(layout); + } + + // Fails due to non-injective input + //{ + //auto layout = make_layout(Shape,Shape<_2, _2>>{}, + // Stride,Stride<_8,_4>>{}); + + //test_complement(layout); + //} + + { + auto layout = Layout, Stride<_1,_6>>{}; + + test_complement(layout); + } + + { + auto layout = Layout, Stride<_1,_10>>{}; + + test_complement(layout); + } + + { + auto layout = Layout, Stride<_1,_16>>{}; + + test_complement(layout); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Dynamic shapes/strides"); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = make_layout(12); + + test_complement(layout, 1); + test_complement(layout); + test_complement(layout, 53); + test_complement(layout, 128); + } + + { + auto layout = make_layout(12, 1); + + test_complement(layout, 1); + test_complement(layout); + test_complement(layout, 53); + test_complement(layout, 128); + } + + { + auto layout = make_layout(12, Int<2>{}); + + test_complement(layout, 1); + test_complement(layout); + test_complement(layout, 53); + test_complement(layout, 128); + } + + { + auto layout = make_layout(12, 2); + + test_complement(layout, 1); + test_complement(layout); + test_complement(layout, 53); + test_complement(layout, 128); + } + + { + auto layout = make_layout(make_shape(3,6),make_stride(_1{}, _3{})); + + test_complement(layout); + } + + { + auto layout = make_layout(make_shape(3,6),make_stride(_1{}, _9{})); + + test_complement(layout); + } + + { + auto layout = make_layout(make_shape(3,6),make_stride(_1{}, _10{})); + + test_complement(layout); + } + + { + auto layout = make_layout(make_shape(make_shape(2,2), make_shape(2,2)), + Stride,Stride<_8,_32>>{}); + + test_complement(layout); + } +} diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp new file mode 100644 index 0000000000..7934b3ceeb --- /dev/null +++ b/test/unit/cute/core/composition.cpp @@ -0,0 +1,528 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include + +#include + +using namespace cute; + + +template +void +test_composition(const LayoutA& layoutA, + const LayoutB& layoutB) +{ + auto layoutR = composition(layoutA, layoutB); + + CUTLASS_TRACE_HOST("test_composition()"); + CUTLASS_TRACE_HOST(layoutA << " o " << layoutB); + CUTLASS_TRACE_HOST(" => "); + CUTLASS_TRACE_HOST(layoutR); + + // Test that layout R is compatible with layout B + EXPECT_TRUE(compatible(layoutB, layoutR)); + + // True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. + + // Test that R(c) = A(B(c)) for all coordinates c in layoutR + for (int i = 0; i < size(layoutR); ++i) { + EXPECT_EQ(layoutR(i), layoutA(layoutB(i))); + } +} + + +TEST(CuTe_core, Composition) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("COMPOSITION" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Simple tests" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = Layout<_1,_0>{}; + auto b = Layout<_1,_0>{}; + + test_composition(a, b); + } + + { + auto a = Layout<_1,_0>{}; + auto b = Layout<_1,_1>{}; + + test_composition(a, b); + } + + { + auto a = Layout<_1,_1>{}; + auto b = Layout<_1,_0>{}; + + test_composition(a, b); + } + + { + auto a = Layout<_1,_1>{}; + auto b = Layout<_1,_1>{}; + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_2>{}); + auto b = make_layout(Shape<_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_0>{}); + auto b = make_layout(Shape<_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_4>{}, Stride<_0>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_1>{}, Stride<_0>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_2>{}); + auto b = make_layout(Shape<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_2>{}, Stride<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_2>{}); + auto b = make_layout(Shape<_2>{}, Stride<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_12>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_12>{}); + auto b = make_layout(Shape<_4,_3>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_12>{}, Stride<_2>{}); + auto b = make_layout(Shape<_4,_3>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_12>{}); + auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_12>{}, Stride<_2>{}); + auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_12>{}); + auto b = make_layout(Shape<_2,_3>{}, Stride<_2,_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_4,_3>{}); + + test_composition(a, b); + } + + // FAILS due to b not "dividing into" a properly + //{ + // auto a = make_layout(Shape<_4,_3>{}); + // auto b = make_layout(Shape<_6>{}); + + // test_composition(a, b); + //} + + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_6>{}, Stride<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_6,_2>{}, Stride<_2,_1>{}); + + test_composition(a, b); + } + + // FAILS due to b not "dividing into" a properly + //{ + // auto a = make_layout(Shape<_4,_3>{}); + // auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + + // test_composition(a, b); + //} + + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_4,_3>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_12>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_6>{}, Stride<_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_6,_2>{}, Stride<_2,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_8,_8>{}); + auto b = make_layout(Shape, Shape<_2,_2, _2>>{}, + Stride, Stride<_8,_2,_32>>{}); + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + auto b = make_layout(Shape, Shape<_2,_2, _2>>{}, + Stride, Stride<_8,_2,_32>>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape>{}, Stride>{}); + auto b = make_layout(Shape<_4,_2>{}, Stride<_2,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_2,_2>{}, Stride<_2,_1>{}); + auto b = make_layout(Shape<_2,_2>{}, Stride<_2,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_8,_2>{}); + auto b = make_layout(Shape<_2,_2,_2>{}, Stride<_2,_8,_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_8,_2>{}, Stride<_2,_8,_1>{}); + auto b = make_layout(Shape<_2,_2,_2>{}, Stride<_1,_8,_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_8,_2>{}, Stride<_2,_8,_1>{}); + auto b = make_layout(Shape<_4,_2,_2>{}, Stride<_2,_8,_1>{}); + + test_composition(a, b); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Dynamic shapes/strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + + { + auto a = make_layout(12, 1); + auto b = make_layout(_4{}, _1{}); + + test_composition(a, b); + } + + { + auto a = make_layout(12, 1); + auto b = make_layout(_4{}, 1); + + test_composition(a, b); + } + + { + auto a = make_layout(12, _1{}); + auto b = make_layout(_4{}, 1); + + test_composition(a, b); + } + + { + auto a = make_layout(12, _1{}); + auto b = make_layout(_4{}, _1{}); + + test_composition(a, b); + } + + { + auto a = make_layout(make_shape(12,3), make_stride(1,24)); + auto b = make_layout(Shape<_4>{}, Stride<_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(16, 2); + auto b = make_layout(4, 2); + + test_composition(a, b); + } + + { + auto a = make_layout(make_shape(128,24,5), make_stride(1,128,3072)); + auto b = make_layout(64, 2); + + test_composition(a, b); + } + + { + auto a = make_layout(make_shape(128,24,5), make_stride(1,128,3072)); + auto b = make_layout(480, Int<32>{}); + + test_composition(a, b); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("cosize(b) > size(a) and divisibility"); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = make_layout(Shape<_1>{}, Stride<_0>{}); + auto b = make_layout(Shape<_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_1>{}, Stride<_1>{}); + auto b = make_layout(Shape<_4>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}); + auto b = make_layout(Shape<_4>{}, Stride<_2>{}); + + test_composition(a, b); + } + + // Last mode gets extended + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_24>{}); + + test_composition(a, b); + } + + // Last mode extension even without last mode divisibility + { + auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + auto b = make_layout(Shape<_8>{}); + + test_composition(a, b); + } + + // Capping a Layout with 1:0 forces divisibility and extends in stride-0 + { + auto a = make_layout(Shape<_4,_3,_1>{}, Stride<_3,_1,_0>{}); + auto b = make_layout(Shape<_24>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(3, _1{}); + auto b = make_layout(_4{}, _1{}); + + test_composition(a, b); + } + + { + auto a = make_layout(make_shape(48,24,5), make_stride(_1{},128,3072)); + auto b = make_layout(32, Int<1>{}); + + test_composition(a, b); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Swizzle composition" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = Layout, Stride<_8,_1>>{}; + auto b = composition(Swizzle<2,0,-3>{}, Layout, Stride<_8,_1>>{}); + + test_composition(a, b); + } + + { + auto a = composition(Swizzle<2,0, 3>{}, Layout, Stride<_8,_1>>{}); + auto b = composition(Swizzle<2,0,-3>{}, Layout, Stride<_8,_1>>{}); + + test_composition(a, b); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("BETA: Negative strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = make_layout(Shape<_4>{}, Stride<_m1>{}); + auto b = make_layout(Shape<_4>{}, Stride<_1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_1>{}); + auto b = make_layout(Shape<_4>{}, Stride<_m1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_m1>{}); + auto b = make_layout(Shape<_4>{}, Stride<_m1>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4>{}, Stride<_1>{}); + auto b = make_layout(Shape<_4>{}, Stride<_m2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_4>{}, Stride<_m1,_1>{}); + auto b = make_layout(Shape<_2,_4,_2>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,_4>{}, Stride<_m1,_1>{}); + auto b = make_layout(Shape<_2,_4,_2>{}, Stride<_1,_4,_2>{}); + + test_composition(a, b); + } + + // The SM80 fp64 MMA NT problem + { + auto a = make_layout(Shape<_1,Shape<_2,_4>>{}, Stride<_0,Stride<_m1,_512>>{}); + auto b = make_layout(_2{}, _m1{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_1,Shape<_2,_4>>{}, Stride<_0,Stride<_m1,_512>>{}); + auto b = make_layout(_4{}, _m1{}); + + test_composition(a, b); + } + +} diff --git a/test/unit/cute/core/inverse_left.cpp b/test/unit/cute/core/inverse_left.cpp new file mode 100644 index 0000000000..cea354a838 --- /dev/null +++ b/test/unit/cute/core/inverse_left.cpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include + +#include + +using namespace cute; + +template +void +test_left_inverse(Layout const& layout) +{ + auto inv_layout = left_inverse(layout); + + CUTLASS_TRACE_HOST(layout << " ^ -1\n" << " => \n" << inv_layout); + + for (int i = 0; i < size(layout); ++i) { + //printf("%3d: %3d %3d\n", i, int(layout(i)), int(inv_layout(layout(i)))); + EXPECT_EQ(inv_layout(layout(i)), i); + } + + CUTLASS_TRACE_HOST("Composition: " << coalesce(composition(inv_layout, layout))); +} + +TEST(CuTe_core, Inverse_left) +{ + { + auto layout = Layout, + Stride<_0>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout>, + Stride>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_2>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_4, _1>>{}; + + test_left_inverse(filter(layout)); + } + + { + auto layout = Layout>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_4,_1,_8>>{}; + + test_left_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1,_16>>{}; + + test_left_inverse(layout); + } + + // + // Swizzle left_inverse + // + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_1, _4>>>{}; + + test_left_inverse(layout); + } + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_4, _1>>>{}; + + test_left_inverse(layout); + } + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_8, _1>>>{}; + + test_left_inverse(layout); + } + + // + // Negative strides (beta support) + // Post-conditions/layout indexing aren't generalized enough to support these yet + // However, the composition post-condition is general enough. + { + auto layout = make_layout(Shape<_4>{}, Stride>{}); + + test_left_inverse(layout); + } + + //{ + //auto layout = Layout, + // Stride<_m1,_2>>{}; + + //test_left_inverse(layout); + //} + + //{ + //auto layout = Layout, + // Stride< _4,_m1>>{}; + + //test_left_inverse(layout); + //} + + //{ + //auto layout = Layout, + // Stride<_m1,_12,_m2>>{}; + + //test_left_inverse(layout); + //} +} diff --git a/test/unit/cute/core/inverse_right.cpp b/test/unit/cute/core/inverse_right.cpp new file mode 100644 index 0000000000..4bb5870bff --- /dev/null +++ b/test/unit/cute/core/inverse_right.cpp @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +#include + +using namespace cute; + +template +void +test_right_inverse(Layout const& layout) +{ + auto inv_layout = right_inverse(layout); + + CUTLASS_TRACE_HOST(layout << " ^ -1\n" << " => \n" << inv_layout); + CUTLASS_TRACE_HOST("Composition: " << coalesce(composition(layout, inv_layout)) << std::endl); + + for (int i = 0; i < size(inv_layout); ++i) { + //printf("%3d: %3d %3d\n", i, int(inv_layout(i)), int(layout(inv_layout(i)))); + EXPECT_EQ(layout(inv_layout(i)), i); + } +} + +TEST(CuTe_core, Inverse_right) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("RIGHT INVERSE" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Simple tests" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = Layout<_1, _0>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout<_1, _1>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_0>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout>, + Stride>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout>, + Stride>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_2>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_0,_2>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_4, _1>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_4,_1,_8>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_4,_1,_0,_8>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1,_16>>{}; + + test_right_inverse(layout); + } + + { + auto layout = Layout, + Stride<_1, _5>>{}; + + test_right_inverse(layout); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Dynamic shapes/strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = make_layout(Shape<_4, _2>{}, make_stride(Int<1>{}, 4)); + + test_right_inverse(layout); + } + + { + auto layout = make_layout(make_shape(_4{}, 2), make_stride(Int<1>{}, 4)); + + test_right_inverse(layout); + } + + { + auto layout = make_layout(make_shape(4, 2), make_stride(Int<1>{}, 4)); + + test_right_inverse(layout); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("Swizzle layouts" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_1, _4>>>{}; + + test_right_inverse(layout); + } + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_4, _1>>>{}; + + test_right_inverse(layout); + } + + { + auto layout = ComposedLayout, _0, Layout, + Stride<_8, _1>>>{}; + + test_right_inverse(layout); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("BETA: Negative strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + // Negative strides (beta support) + // Post-conditions/layout indexing aren't generalized enough to support these yet + // However, the composition post-condition is general enough. + { + auto layout = make_layout(Shape<_4>{}, Stride>{}); + + test_right_inverse(layout); + } + + //{ + //auto layout = Layout, + // Stride<_m1,_2>>{}; + + //test_right_inverse(layout); + //} + + //{ + //auto layout = Layout, + // Stride< _4,_m1>>{}; + + //test_right_inverse(layout); + //} + + //{ + //auto layout = Layout, + // Stride<_m1,_12,_m2>>{}; + + //test_right_inverse(layout); + //} + +} diff --git a/test/unit/cute/core/logical_divide.cpp b/test/unit/cute/core/logical_divide.cpp new file mode 100644 index 0000000000..5d37b8295f --- /dev/null +++ b/test/unit/cute/core/logical_divide.cpp @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include + +using namespace cute; + +template +void +test_logical_divide(LayoutA const& layoutA, + LayoutB const& layoutB) +{ + auto layoutR = logical_divide(layoutA, layoutB); + + CUTLASS_TRACE_HOST("test_logical_divide()"); + CUTLASS_TRACE_HOST(shape(layoutA) << " / " << shape(layoutB) << " => " << shape(layoutR) ); + CUTLASS_TRACE_HOST(stride(layoutA) << " " << stride(layoutB) << " => " << stride(layoutR)); + + // Test that layout R is compatible with layout B + ASSERT_EQ(rank(layoutR), 2); + ASSERT_TRUE(compatible(layoutB, layout<0>(layoutR))); +} + +TEST(CuTe_core, Logical_divide) +{ + { + auto layout = Layout<_1,_0>{}; + auto tile = Layout<_1,_0>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_1,_0>{}; + auto tile = Layout<_1,_1>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_1,_1>{}; + auto tile = Layout<_1,_0>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_1,_1>{}; + auto tile = Layout<_1,_1>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_1>{}; + auto tile = Layout<_2,_1>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_1>{}; + auto tile = Layout<_2,_3>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_1>{}; + auto tile = Layout,Stride<_3,_1>>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_2>{}; + auto tile = Layout<_2,_1>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_2>{}; + auto tile = Layout<_2,_3>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_6,_2>{}; + auto tile = Layout,Stride<_3,_1>>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_1,_12>>{}; + auto tile = Layout,Stride<_3,_1>>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_12,_1>>{}; + auto tile = Layout,Stride<_3,_1>>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout<_32>{}; + auto tile = Layout<_2,_8>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_1,_1>>{}; + auto tile = Layout<_2,_1>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_1,_1>>{}; + auto tile = Layout<_2,_2>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_1,_8>>{}; + auto tile = Layout<_32,_2>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = Layout,Stride<_8,_1>>{}; + auto tile = Layout<_32,_2>{}; + + test_logical_divide(layout, tile); + } + + // + // Dynamic + // + + { + auto layout = make_layout(2); + auto tile = Layout<_32>{}; + + test_logical_divide(layout, tile); + + // Enforcement for dynamic cases + auto result = logical_divide(layout, tile); + static_assert(decltype(shape<0>(result) == Int<32>{})::value); + static_assert(decltype(stride<0>(result) == Int<1>{})::value); + assert(shape<1>(result) == 1); + static_assert(decltype(stride<1>(result) == Int<32>{})::value); + } + + { + auto layout = make_layout(48); + auto tile = Layout<_32>{}; + + test_logical_divide(layout, tile); + + // Enforcement for dynamic cases + auto result = logical_divide(layout, tile); + static_assert(decltype(shape<0>(result) == Int<32>{})::value); + static_assert(decltype(stride<0>(result) == Int<1>{})::value); + assert(shape<1>(result) == 2); + static_assert(decltype(stride<1>(result) == Int<32>{})::value); + } + + { + auto layout = make_layout(96); + auto tile = Layout<_32,_2>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = make_layout(32); + auto tile = Layout>{}; + + test_logical_divide(layout, tile); + + // Enforcement for dynamic cases + auto result = logical_divide(layout, tile); + static_assert(decltype(shape<0>(result) == Int<48>{})::value); + static_assert(decltype(stride<0>(result) == Int<1>{})::value); + assert(shape<1>(result) == 1); + static_assert(decltype(stride<1>(result) == Int<48>{})::value); + } + + // DISALLOWED + //{ + //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + //auto tile = Layout<_32>{}; + + //test_logical_divide(layout, tile); + //} + + //{ + //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + //auto tile = Layout<_32,_2>{}; + + //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); + //test_logical_divide(layout, tile); + //} + + //{ + //auto layout = make_layout(make_shape(16,4,3), make_stride(1,512,0)); + //auto tile = Layout<_32>{}; + + //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); + //test_logical_divide(layout, tile); + //} +} diff --git a/test/unit/cute/core/logical_product.cpp b/test/unit/cute/core/logical_product.cpp new file mode 100644 index 0000000000..bcdae4ea93 --- /dev/null +++ b/test/unit/cute/core/logical_product.cpp @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +using namespace cute; + +template +void +test_logical_product(LayoutA const& layoutA, + LayoutB const& layoutB) +{ + auto layoutR = logical_product(layoutA, layoutB); + + CUTLASS_TRACE_HOST(shape(layoutA) << " x " << shape(layoutB) << " => " << shape(layoutR) ); + CUTLASS_TRACE_HOST(stride(layoutA) << " " << stride(layoutB) << " => " << stride(layoutR)); + + // Test that layout R is compatible with layout B + ASSERT_EQ(rank(layoutR), 2); + //assert(compatible(layoutB, layout<0>(layoutR))); + //assert(consistent(layoutA, layout<1>(layoutR))); + + // True post-condition: + +} + +TEST(CuTe_core, Logical_product) +{ + { + auto vec = Layout<_1,_0>{}; + auto tile = Layout<_1,_0>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_1,_1>{}; + auto tile = Layout<_1,_0>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_1,_0>{}; + auto tile = Layout<_1,_1>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_1,_1>{}; + auto tile = Layout<_1,_1>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_3,_1>{}; + auto tile = Layout<_4,_0>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_3,_0>{}; + auto tile = Layout<_4,_1>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_3,_0>{}; + auto tile = Layout<_4,_0>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = Layout<_3,_2>{}; + auto tile = Layout<_4,_1>{}; + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}); + auto tile = make_layout(Shape<_2,_4>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_2,_4>{}); + auto tile = make_layout(Shape<_3>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_8,Shape<_2,_2>>{}); + auto tile = make_layout(Shape<_4>{}, Stride<_2>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_2,_2>{}); + auto tile = make_layout(Shape<_3,_3>{}, Stride<_3,_1>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}, Stride<_32>{}); + auto tile = make_layout(Shape<_32>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}, Stride<_2>{}); + auto tile = make_layout(Shape<_4>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}, Stride<_32>{}); + auto tile = make_layout(Shape<_128>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}, Stride<_32>{}); + auto tile = make_layout(Shape<_8,_8>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape<_3>{}, Stride<_32>{}); + auto tile = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape>{}, Stride>{}); + auto tile = make_layout(Shape<_4,_4>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape>{}, Stride>{}); + auto tile = make_layout(Shape<_4,_2>{}, Stride<_2,_1>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape,Shape<_2, _2>>{}, + Stride,Stride<_8,_32>>{}); + auto tile = make_layout(Shape<_2,_2>{}, Stride<_1,_2>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape,Shape<_2, _2>>{}, + Stride,Stride<_8,_32>>{}); + auto tile = make_layout(Shape<_2,_2>{}, + Stride<_2,_1>{}); + + test_logical_product(vec, tile); + } + + { + auto vec = make_layout(Shape >{}, + Stride>{}); + auto tile = make_layout(Shape <_3>{}, + Stride<_1>{}); + + test_logical_product(vec, tile); + } +} diff --git a/test/unit/cute/core/mixedbits.cpp b/test/unit/cute/core/mixedbits.cpp new file mode 100644 index 0000000000..55027ebd24 --- /dev/null +++ b/test/unit/cute/core/mixedbits.cpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +TEST(CuTe_core, MixedBits) { + using namespace cute; + + auto uzero = cute::integral_constant{}; + + for_each(make_integer_sequence{}, [&](auto S0) { + for_each(make_integer_sequence{}, [&](auto F0) { + for_each(make_integer_sequence{}, [&](auto S1) { + for_each(make_integer_sequence{}, [&](auto F1) { + if constexpr (decltype(S0 == uzero || S1 == uzero)::value) { + return; + } else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) { + return; + } else { + for (uint32_t d0 = 0; d0 < 8; ++d0) { + if ((d0 & F0) != d0) { continue; } // Skip repeats + for (uint32_t d1 = 0; d1 < 8; ++d1) { + if ((d1 & F1) != d1) { continue; } // Skip repeats + auto m0 = make_mixed_bits(S0, d0, F0); + auto m1 = make_mixed_bits(S1, d1, F1); + //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n"); + EXPECT_EQ(to_integral(m0 & m1), to_integral(m0) & to_integral(m1)); + //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n"); + EXPECT_EQ(to_integral(m0 | m1), to_integral(m0) | to_integral(m1)); + //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n"); + EXPECT_EQ(to_integral(m0 ^ m1), to_integral(m0) ^ to_integral(m1)); + } + } + } + }); + }); + }); + }); +} diff --git a/test/unit/cute/core/transform.cpp b/test/unit/cute/core/transform.cpp new file mode 100644 index 0000000000..c929f401ab --- /dev/null +++ b/test/unit/cute/core/transform.cpp @@ -0,0 +1,49 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include + +TEST(CuTe_core, Transform) { + using namespace cute; + complex array[4] = {{0,0}, {1,0}, {0,1}, {1,1}}; + complex correct[4] = {{0,0}, {1,0}, {0,-1}, {1,-1}}; + auto tensor = make_tensor(static_cast*>(array), make_layout(make_shape(4))); + conjugate conj; + transform(tensor, conj); + for (int i = 0; i < 4; ++i) + { + EXPECT_EQ(tensor(i), correct[i]); + } +} diff --git a/test/unit/cute/core/tuple.cpp b/test/unit/cute/core/tuple.cpp new file mode 100644 index 0000000000..a53121aaa5 --- /dev/null +++ b/test/unit/cute/core/tuple.cpp @@ -0,0 +1,266 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +TEST(CuTe_core, Tuple) +{ + using namespace cute; + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("SIMPLE STATIC AND DYNAMIC TUPLES"); + CUTLASS_TRACE_HOST("-------------------------------"); + + using tuple_2d_s_type = tuple<_8, _4>; // (8,4) + using tuple_3d_s_type = tuple<_8, _4, _2>; // (8,4,2) + using tuple_3h_s_type = tuple, _8, _2>; // ((1,2),8,2) + + using tuple_2d_d_type = tuple; // (8,4) + using tuple_3d_d_type = tuple; // (8,4,2) + using tuple_3h_d_type = tuple, int, int>; // ((1,2),8,2) + + using tuple_2d_m_type = tuple<_8, int>; // (8,4) + using tuple_3d_m_type = tuple; // (8,4,2) + using tuple_3h_m_type = tuple, int, int>; // ((1,2),8,2) + + tuple_2d_s_type tuple_2d_s; + tuple_3d_s_type tuple_3d_s; + tuple_3h_s_type tuple_3h_s; + + tuple_2d_d_type tuple_2d_d(8,4); + tuple_3d_d_type tuple_3d_d(8,4,2); + tuple_3h_d_type tuple_3h_d(tuple(1,2),8,2); + + tuple_2d_m_type tuple_2d_m(_8{}, 4); + tuple_3d_m_type tuple_3d_m(8,4,_2{}); + tuple_3h_m_type tuple_3h_m(tuple(1,_2{}),8,2); + + CUTLASS_TRACE_HOST(tuple_2d_s << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_2d_s_type)); + ASSERT_TRUE(is_static::value == true); + ASSERT_TRUE(sizeof(tuple_2d_s_type) == 1); + ASSERT_TRUE(std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3d_s << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3d_s_type)); + ASSERT_TRUE(is_static::value == true); + ASSERT_TRUE(sizeof(tuple_3d_s_type) == 1); + ASSERT_TRUE(std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3h_s << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3h_s_type)); + ASSERT_TRUE(is_static::value == true); + ASSERT_TRUE(sizeof(tuple_3h_s_type) == 1); + ASSERT_TRUE(std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_2d_d << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_2d_d_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_2d_d_type) == 8); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3d_d << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3d_d_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_3d_d_type) == 12); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3h_d << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3h_d_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_3h_d_type) == 16); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_2d_m << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_2d_m_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_2d_m_type) == 4); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3d_m << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3d_m_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_3d_m_type) == 8); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST(tuple_3h_m << (is_static::value ? " Static " : " Dynamic ") + << "sizeof = " << sizeof(tuple_3h_m_type)); + ASSERT_TRUE(is_static::value == false); + ASSERT_TRUE(sizeof(tuple_3h_m_type) == 12); + ASSERT_TRUE(!std::is_empty::value); + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("SIMPLE TUPLE OPS"); + CUTLASS_TRACE_HOST("-------------------------------"); + + CUTLASS_TRACE_HOST("product(" << tuple_2d_s << ") => " << product(tuple_2d_s)); + CUTE_STATIC_ASSERT_V(product(tuple_2d_s) == _32{}); + CUTLASS_TRACE_HOST("product(" << tuple_3d_s << ") => " << product(tuple_3d_s)); + CUTE_STATIC_ASSERT_V(product(tuple_3d_s) == _64{}); + CUTLASS_TRACE_HOST("product(" << tuple_3h_s << ") => " << product(tuple_3h_s)); + CUTE_STATIC_ASSERT_V(product(tuple_3h_s) == _32{}); + + CUTLASS_TRACE_HOST("product(" << tuple_2d_d << ") => " << product(tuple_2d_d)); + ASSERT_TRUE(product(tuple_2d_d) == 32); + CUTLASS_TRACE_HOST("product(" << tuple_3d_d << ") => " << product(tuple_3d_d)); + ASSERT_TRUE(product(tuple_3d_d) == 64); + CUTLASS_TRACE_HOST("product(" << tuple_3h_d << ") => " << product(tuple_3h_d)); + ASSERT_TRUE(product(tuple_3h_d) == 32); + + CUTLASS_TRACE_HOST("product(" << tuple_2d_m << ") => " << product(tuple_2d_m)); + ASSERT_TRUE(product(tuple_2d_m) == 32); + CUTLASS_TRACE_HOST("product(" << tuple_3d_m << ") => " << product(tuple_3d_m)); + ASSERT_TRUE(product(tuple_3d_m) == 64); + CUTLASS_TRACE_HOST("product(" << tuple_3h_m << ") => " << product(tuple_3h_m)); + ASSERT_TRUE(product(tuple_3h_m) == 32); + + CUTLASS_TRACE_HOST("max(" << tuple_2d_s << ") => " << max(tuple_2d_s)); + CUTE_STATIC_ASSERT_V(max(tuple_2d_s) == _8{}); + CUTLASS_TRACE_HOST("max(" << tuple_3d_s << ") => " << max(tuple_3d_s)); + CUTE_STATIC_ASSERT_V(max(tuple_3d_s) == _8{}); + CUTLASS_TRACE_HOST("max(" << tuple_3h_s << ") => " << max(tuple_3h_s)); + CUTE_STATIC_ASSERT_V(max(tuple_3h_s) == _8{}); + + CUTLASS_TRACE_HOST("max(" << tuple_2d_d << ") => " << max(tuple_2d_d)); + ASSERT_TRUE(max(tuple_2d_d) == 8); + CUTLASS_TRACE_HOST("max(" << tuple_3d_d << ") => " << max(tuple_3d_d)); + ASSERT_TRUE(max(tuple_3d_d) == 8); + CUTLASS_TRACE_HOST("max(" << tuple_3h_d << ") => " << max(tuple_3h_d)); + ASSERT_TRUE(max(tuple_3h_d) == 8); + + CUTLASS_TRACE_HOST("max(" << tuple_2d_m << ") => " << max(tuple_2d_m)); + ASSERT_TRUE(max(tuple_2d_m) == 8); + CUTLASS_TRACE_HOST("max(" << tuple_3d_m << ") => " << max(tuple_3d_m)); + ASSERT_TRUE(max(tuple_3d_m) == 8); + CUTLASS_TRACE_HOST("max(" << tuple_3h_m << ") => " << max(tuple_3h_m)); + ASSERT_TRUE(max(tuple_3h_m) == 8); + + // 2d s|d|m + CUTLASS_TRACE_HOST("inner_product(" << tuple_2d_s << ", " << tuple_2d_s << ") => " + << inner_product(tuple_2d_s, tuple_2d_s)); + CUTE_STATIC_ASSERT_V(inner_product(tuple_2d_s, tuple_2d_s) == Int<80>{}); + CUTLASS_TRACE_HOST("inner_product(" << tuple_2d_d << ", " << tuple_2d_d << ") => " + << inner_product(tuple_2d_d, tuple_2d_d)); + ASSERT_TRUE(inner_product(tuple_2d_d, tuple_2d_d) == 80); + CUTLASS_TRACE_HOST("inner_product(" << tuple_2d_m << ", " << tuple_2d_m << ") => " + << inner_product(tuple_2d_m, tuple_2d_m)); + ASSERT_TRUE(inner_product(tuple_2d_m, tuple_2d_m) == 80); + + // 3d s|d|m + CUTLASS_TRACE_HOST("inner_product(" << tuple_3d_s << ", " << tuple_3d_s << ") => " + << inner_product(tuple_3d_s, tuple_3d_s)); + CUTE_STATIC_ASSERT_V(inner_product(tuple_3d_s, tuple_3d_s) == Int<84>{}); + CUTLASS_TRACE_HOST("inner_product(" << tuple_3d_d << ", " << tuple_3d_d << ") => " + << inner_product(tuple_3d_d, tuple_3d_d)); + ASSERT_TRUE(inner_product(tuple_3d_d, tuple_3d_d) == 84); + CUTLASS_TRACE_HOST("inner_product(" << tuple_3d_m << ", " << tuple_3d_m << ") => " + << inner_product(tuple_3d_m, tuple_3d_m)); + ASSERT_TRUE(inner_product(tuple_3d_m, tuple_3d_m) == 84); + + // 3h s|d|m + CUTLASS_TRACE_HOST("inner_product(" << tuple_3h_s << ", " << tuple_3h_s << ") => " + << inner_product(tuple_3h_s, tuple_3h_s)); + CUTE_STATIC_ASSERT_V(inner_product(tuple_3h_s, tuple_3h_s) == Int<73>{}); + CUTLASS_TRACE_HOST("inner_product(" << tuple_3h_d << ", " << tuple_3h_d << ") => " + << inner_product(tuple_3h_d, tuple_3h_d)); + ASSERT_TRUE(inner_product(tuple_3h_d, tuple_3h_d) == 73); + CUTLASS_TRACE_HOST("inner_product(" << tuple_3h_m << ", " << tuple_3h_m << ") => " + << inner_product(tuple_3h_m, tuple_3h_m)); + ASSERT_TRUE(inner_product(tuple_3h_m, tuple_3h_m) == 73); + + CUTLASS_TRACE_HOST("col_major(" << tuple_2d_s << ") => " << compact_col_major(tuple_2d_s)); + CUTE_STATIC_ASSERT_V((compact_col_major(tuple_2d_s) == make_tuple(_1{},_8{}))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3d_s << ") => " << compact_col_major(tuple_3d_s)); + CUTE_STATIC_ASSERT_V((compact_col_major(tuple_3d_s) == make_tuple(_1{},_8{},_32{}))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3h_s << ") => " << compact_col_major(tuple_3h_s)); + CUTE_STATIC_ASSERT_V((compact_col_major(tuple_3h_s) == make_tuple(make_tuple(_0{},_1{}),_2{},_16{}))); + + CUTLASS_TRACE_HOST("col_major(" << tuple_2d_d << ") => " << compact_col_major(tuple_2d_d)); + ASSERT_TRUE((compact_col_major(tuple_2d_d) == make_tuple(_1{},8))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3d_d << ") => " << compact_col_major(tuple_3d_d)); + ASSERT_TRUE((compact_col_major(tuple_3d_d) == make_tuple(_1{},8,32))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3h_d << ") => " << compact_col_major(tuple_3h_d)); + ASSERT_TRUE((compact_col_major(tuple_3h_d) == make_tuple(make_tuple(_1{},1),2,16))); + + CUTLASS_TRACE_HOST("col_major(" << tuple_2d_m << ") => " << compact_col_major(tuple_2d_m)); + ASSERT_TRUE((compact_col_major(tuple_2d_m) == make_tuple(_1{},_8{}))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3d_m << ") => " << compact_col_major(tuple_3d_m)); + ASSERT_TRUE((compact_col_major(tuple_3d_m) == make_tuple(_1{},8,32))); + CUTLASS_TRACE_HOST("col_major(" << tuple_3h_m << ") => " << compact_col_major(tuple_3h_m)); + ASSERT_TRUE((compact_col_major(tuple_3h_m) == make_tuple(make_tuple(_1{},1),2,16))); + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("SLICING TUPLES"); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = Coord<_2,_3,_4,Coord<_5,_6>>{}; + + CUTLASS_TRACE_HOST("a = " << a); + + CUTLASS_TRACE_HOST("a(1) = " << slice(1, a)); + + CUTLASS_TRACE_HOST("a(_) = " << slice(_, a)); + + CUTLASS_TRACE_HOST("a(_,1,_,_) = " << slice(make_coord(_,1,_,_), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(_,_)) = " << slice(make_coord(_,1,_,make_coord(_,_)), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(_,2)) = " << slice(make_coord(_,1,_,make_coord(_,2)), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(1,2)) = " << slice(make_coord(_,1,_,make_coord(1,2)), a)); + } + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("DICING TUPLES"); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = Coord<_2,_3,_4,Coord<_5,_6>>{}; + + CUTLASS_TRACE_HOST("a = " << a); + + CUTLASS_TRACE_HOST("a(1) = " << dice(1, a)); + + CUTLASS_TRACE_HOST("a(_) = " << dice(_, a)); + + CUTLASS_TRACE_HOST("a(_,1,_,_) = " << dice(make_coord(_,1,_,_), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(_,_)) = " << dice(make_coord(_,1,_,make_coord(_,_)), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(_,2)) = " << dice(make_coord(_,1,_,make_coord(_,2)), a)); + + CUTLASS_TRACE_HOST("a(_,1,_,(1,2)) = " << dice(make_coord(_,1,_,make_coord(1,2)), a)); + } +} diff --git a/test/unit/cute/hopper/CMakeLists.txt b/test/unit/cute/hopper/CMakeLists.txt new file mode 100644 index 0000000000..ce30110113 --- /dev/null +++ b/test/unit/cute/hopper/CMakeLists.txt @@ -0,0 +1,58 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +add_custom_target( + cutlass_test_unit_cute_hopper + DEPENDS + cutlass_test_unit_cute_hopper_stsm + cutlass_test_unit_cute_hopper_tma_load + cutlass_test_unit_cute_hopper_tma_store +) + +add_custom_target( + test_unit_cute_hopper + DEPENDS + test_unit_cute_hopper_stsm + test_unit_cute_hopper_tma_load + test_unit_cute_hopper_tma_store +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_stsm + stsm.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_tma_load + tma_load.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_tma_store + tma_store.cu +) diff --git a/test/unit/cute/hopper/stsm.cu b/test/unit/cute/hopper/stsm.cu new file mode 100644 index 0000000000..ffc8aa74fc --- /dev/null +++ b/test/unit/cute/hopper/stsm.cu @@ -0,0 +1,426 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include +#include + +using namespace cute; + +template +__global__ void +stsm_test_device(uint16_t* g_in, uint16_t* g_out) +{ + constexpr int count = sizeof(T) / 4; + int tid = threadIdx.x; + int stride = blockDim.x; + + // load input gmem -> rmem + uint32_t reg[count]; + for (int i = 0; i < (sizeof(T) / 4); i++) { + reg[i] = reinterpret_cast(g_in)[tid + (stride * i)]; + } + + __shared__ uint32_t smem[32 * count]; + + // load rmem -> smem using STSM + uint128_t* smem_ptr = reinterpret_cast(smem) + tid; + T* rmem_ptr = reinterpret_cast(reg); + cute::copy_stsm(rmem_ptr, smem_ptr); + + __syncthreads(); + + // store output smem -> gmem + for (int i = 0; i < (sizeof(T) / 4); i++) { + reinterpret_cast(g_out)[tid + (stride * i)] = smem[tid + (stride * i)]; + } +} + +template +__global__ void +stsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, + TiledCopy tiled_copy, SmemLayout smem_layout) +{ + using namespace cute; + + __shared__ uint16_t smem[size(smem_layout)]; + + Tensor t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); + Tensor t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); + Tensor t_smem = make_tensor(make_smem_ptr(smem), smem_layout); + + int tid = threadIdx.x; + + auto thr_copy = tiled_copy.get_thread_slice(tid); + + Tensor tXgX = thr_copy.partition_S(t_g_in); // (V,M,N) + Tensor tXsX = thr_copy.partition_D(t_smem); // (V,M,N) + + Tensor tXrX = make_tensor(shape(tXgX)); // (V,M,N) + clear(tXrX); // Just to make sure + +/* + if (thread0()) { + print("tXsX: " ); print(tXsX.layout()); print("\n"); + print("tXgX: " ); print(tXgX.layout()); print("\n"); + print("tXrX: " ); print(tXrX.layout()); print("\n"); + } +*/ + + // Load input gmem -> rmem + copy(tXgX, tXrX); + + // Copy rmem -> smem via tiled_copy (STSM, STS) + copy(tiled_copy, tXrX, tXsX); + + // Output smem -> gmem + for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { + t_g_out(i) = t_smem(i); + } +} + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_CuTe_Hopper, Stsm) +{ + constexpr int count = 1024; + + thrust::host_vector h_in(count); + for (int i = 0; i < count; ++i) { + h_in[i] = uint16_t(i); + } + thrust::device_vector d_in = h_in; + + // + // STSM 1x (32b) + // + + { + thrust::device_vector d_out(count); + stsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 32; ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("STSM 1x stsm_test_device SUCCESS\n"); + } + + // + // STSM 2x (64b) + // + + { + thrust::device_vector d_out(count); + stsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 64; ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("STSM 2x stsm_test_device SUCCESS\n"); + } + + // + // STSM 4x (128b) + // + + { + thrust::device_vector d_out(count); + stsm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + for (int i = 0; i < 128; ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("STSM 4x stsm_test_device SUCCESS\n"); + } + + // + // CuTe STSM + // + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout>, + Stride< _2,Stride<_1,_64>>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved STS.U16 SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_STSM_N SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 STS.U16 SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_STSM_T SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_STSM_T SUCCESS\n"); + } + + { + thrust::device_vector d_out(count); + + auto smem_layout = Layout, + Stride<_32, _1>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + smem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_STSM_T SUCCESS\n"); + } + + CUTLASS_TRACE_HOST("PASS"); +} +#endif diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu new file mode 100644 index 0000000000..24f17fca62 --- /dev/null +++ b/test/unit/cute/hopper/tma_load.cu @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::array_aligned> smem; + cute::uint64_t tma_load_mbar[1]; +}; + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + assert(product_each(shape(gmem_layout)) == product_each(smem_layout.shape())); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); + +#if 0 + + // + // Read in trivially + // + + Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); + + // Input gmem -> smem + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + sA(i) = gA_in(i); + } + __syncthreads(); + +#else + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); + + // + // Prepare the TMA_LOAD + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + + Tensor tAgA = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N) + Tensor tAsA = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + +#if 0 + if (thread0()) { + print(" gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); + print("tAgA: "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print(" sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); + print("tAsA: "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); + } +#endif + + // + // Perform the TMA_LOAD + // + + // Group the TMA_M and TMA_N modes + Tensor tAgA_2 = group_modes<1,rank(tAgA)>(tAgA); // (TMA,Rest) + Tensor tAsA_TR = group_modes<1,rank(tAsA)>(tAsA); // (TMA,Rest) + static_assert(size<1>(tAsA_TR) == 1); + Tensor tAsA_2 = tAsA_TR(_,0); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA_2); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = size(sA) * sizeof(T); + + if (threadIdx.x == 0) + { + /// Initialize shared memory barrier + tma_load_mbar[0] = 0; + cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0]), tAgA_2(_,stage), tAsA_2); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + cute::wait_barrier(tma_load_mbar[0], kPhaseBit); + + #endif + + // + // Write out trivially + // + + Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); + // Do the same slicing and grouping as sA + Tensor tAgA_out = cta_tma.partition_D(gA_out); // (TMA,TMA_M,TMA_N) + Tensor tAgA_2_out = group_modes<1,rank(tAgA_out)>(tAgA_out); // (TMA,Rest) + + // Output smem -> gmem + for (int i = threadIdx.x; i < size(tAsA_2); i += blockDim.x) { + tAgA_2_out(i,stage) = tAsA_2(i); + } + __syncthreads(); + } +} + +TEST(SM90_CuTe_Hopper, Tma_load_32x32_Col) +{ + using T = half_t; + Layout smem_layout = Layout, Stride<_1,_32>>{}; + Layout gmem_layout = smem_layout; + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 ColMajor SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_32x32_Row) +{ + using T = half_t; + Layout smem_layout = Layout, Stride<_32,_1>>{}; + Layout gmem_layout = smem_layout; + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 RowMajor SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN) +{ + using T = half_t; + auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_K) +{ + using T = half_t; + auto smem_layout = GMMA::Layout_K_SW128_Atom{}; + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_K_SW128_Atom SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi) +{ + using T = half_t; + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi2) +{ + using T = half_t; + // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi_Dyn) +{ + using T = half_t; + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); + Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_32x32_Multimode) +{ + using T = half_t; + auto smem_layout = Layout, Stride<_32,_1>>{}; + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); + + //auto smem_layout = Layout>{}; + //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking) +{ + using T = half_t; + auto gmem_layout = make_shape(make_shape(336,40),make_shape(32,656)); // GMEM + auto cta_tile = make_shape(make_shape(_16{},_8{}),make_shape(_32{},_2{})); // GMEM Tiling: + // Take 16-elem from m0, 8-elem from m1, + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(cta_tile); // Col-Major SMEM + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking_2) +{ + using T = half_t; + auto gmem_layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(cta_tile); // Col-Major SMEM + + thrust::host_vector h_in(size(gmem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking 2 SUCCESS\n"); +} +#endif diff --git a/test/unit/cute/hopper/tma_store.cu b/test/unit/cute/hopper/tma_store.cu new file mode 100644 index 0000000000..448b7f9125 --- /dev/null +++ b/test/unit/cute/hopper/tma_store.cu @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::array_aligned> smem; +}; + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); + + // + // Read in trivially + // + + Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); + + // Input gmem -> smem + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + sA(i) = gA_in(i); + } + + __syncthreads(); + +#if 0 + + // + // Write out trivially + // + + Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Output smem -> gmem + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + gA_out(i) = sA(i); + } + +#else + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); + + // + // Prepare the TMA_STORE + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + + Tensor tAsA = cta_tma.partition_S(sA); + Tensor tAgA = cta_tma.partition_D(gA); + + // + // Perform the TMA_STORE + // + + if (threadIdx.x == 0) { + copy(tma, tAsA, tAgA); + } + +#endif +} + +TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) +{ + using T = half_t; + Layout smem_layout = Layout, Stride<_1,_32>>{}; + Layout gmem_layout = smem_layout; + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 ColMajor SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Row) +{ + using T = half_t; + Layout smem_layout = Layout, Stride<_32,_1>>{}; + Layout gmem_layout = smem_layout; + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 RowMajor SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN) +{ + using T = half_t; + auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_K) +{ + using T = half_t; + auto smem_layout = GMMA::Layout_K_SW128_Atom{}; + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_K_SW128_Atom SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi) +{ + using T = half_t; + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi2) +{ + using T = half_t; + // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); + Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi_Dyn) +{ + using T = half_t; + auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); + Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} + +TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Multimode) +{ + using T = half_t; + auto smem_layout = Layout, Stride<_32,_1>>{}; + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); + + //auto smem_layout = Layout>{}; + //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); + + thrust::host_vector h_in(size(smem_layout)); + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tma, + gmem_layout, + smem_layout); + + thrust::host_vector h_out = d_out; + for (int i = 0; i < size(smem_layout); ++i) { + //printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +} +#endif diff --git a/test/unit/cute/layout/CMakeLists.txt b/test/unit/cute/layout/CMakeLists.txt new file mode 100644 index 0000000000..d9e6548c8f --- /dev/null +++ b/test/unit/cute/layout/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_layout + layout_operator.cu + ) diff --git a/test/unit/cute/layout/layout_operator.cu b/test/unit/cute/layout/layout_operator.cu new file mode 100644 index 0000000000..6c44f5aaa0 --- /dev/null +++ b/test/unit/cute/layout/layout_operator.cu @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit tests Generic CuTe Layouts +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/matrix_coord.h" + +// Cute includes +#include +#include + +using namespace cutlass; +using namespace cute; + +namespace test { +namespace layout { + +template + struct Testbed { + + + Testbed() {} + + bool run() { + GenericLayout generic_layout; + Layout layout = Layout::packed({size<0>(generic_layout), size<1>(generic_layout)}); + + for (int m = 0; m < size<0>(generic_layout); m++) { + for (int n = 0; n < size<1>(generic_layout); n++) { + if (generic_layout(m, n) != layout({m, n})) return false; + } + } + + return true; + } + }; + +} +} + +////////////////////////////////////////////////////////////////////////// +// Test Generic CuTe Layouts +////////////////////////////////////////////////////////////////////////// + +/// Canonical Layouts + +TEST(GenericLayout, ColumnMajor) { + using GenericLayout = cute::Layout, Stride<_1, _8>>; + using Layout = cutlass::layout::ColumnMajor; + + test::layout::Testbed testbed; + + EXPECT_TRUE(testbed.run()); +} +////////////////////////////////////////////////////////////////////////// + +TEST(GenericLayout, RowMajor) { + using GenericLayout = cute::Layout, Stride<_4, _1>>; + using Layout = cutlass::layout::RowMajor; + + test::layout::Testbed testbed; + + EXPECT_TRUE(testbed.run()); +} +////////////////////////////////////////////////////////////////////////// + + +/// Swizzle Shared Memory layouts + +TEST(GenericLayout, RowMajorTensorOpMultiplicandCrosswise) { + + using GenericLayout = decltype( + composition( + Swizzle<3,3,3>{}, + Layout, Stride<_64, _1>>{}) + ); + + using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + test::layout::Testbed testbed; + + EXPECT_TRUE(testbed.run()); +} +////////////////////////////////////////////////////////////////////////// + +TEST(GenericLayout, ColumnMajorTensorOpMultiplicandCongruous) { + + using GenericLayout = decltype( + composition( + Swizzle<3,3,4>{}, + Layout>{}) + ); + + using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + + test::layout::Testbed testbed; + + EXPECT_TRUE(testbed.run()); +} +////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index c9ebe87f00..2803a8965c 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -41,6 +41,8 @@ add_custom_target( cutlass_test_unit_gemm_device_tensorop_planar_complex cutlass_test_unit_gemm_device_sparse_tensorop_sm80 cutlass_test_unit_gemv_device + cutlass_test_unit_gemm_device_tensorop_sm90 + cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 ) add_custom_target( @@ -58,6 +60,14 @@ add_custom_target( test_unit_gemm_device_tensorop_planar_complex test_unit_gemm_device_sparse_tensorop_sm80 test_unit_gemv_device + test_unit_gemm_device_tensorop_sm90 +) + +add_custom_target( + cutlass_test_unit_gemm_device_sm90 + DEPENDS + cutlass_test_unit_gemm_device_tensorop_sm90 + cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 ) cutlass_test_unit_add_executable( @@ -78,7 +88,7 @@ cutlass_test_unit_add_executable( simt_cgemm_nt_sm50.cu simt_cgemm_tn_sm50.cu simt_cgemm_tt_sm50.cu - + simt_qgemm_nn_sm50.cu simt_qgemm_nt_sm50.cu simt_qgemm_tn_sm50.cu @@ -88,33 +98,48 @@ cutlass_test_unit_add_executable( simt_dgemm_nt_sm50.cu simt_dgemm_tn_sm50.cu simt_dgemm_tt_sm50.cu - + simt_hgemm_nn_sm50.cu simt_hgemm_nt_sm50.cu simt_hgemm_tn_sm50.cu simt_hgemm_tt_sm50.cu - + simt_igemm_nn_sm50.cu simt_igemm_nt_sm50.cu simt_igemm_tn_sm50.cu simt_igemm_tt_sm50.cu - + simt_int8_igemm_sm61_sliced_k.cu simt_int8_igemm_sm61.cu - + simt_sgemm_nn_sm50.cu simt_sgemm_nt_sm50.cu simt_sgemm_tn_sm50.cu simt_sgemm_tt_sm50.cu - + simt_zgemm_nn_sm50.cu simt_zgemm_nt_sm50.cu simt_zgemm_tn_sm50.cu simt_zgemm_tt_sm50.cu - + gemm_splitk_simt_sm50.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_simt_3x + + BATCH_SOURCES ON + BATCH_SIZE 4 + + + sm50_gemm_f32_f32_f32_simt.cu + sm80_gemm_f32_f32_f32_simt.cu + sm50_gemm_f64_f64_f64_simt.cu + sm80_gemm_f64_f64_f64_simt.cu + sm61_gemm_s8_s8_s32_simt.cu +) + + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm70 @@ -209,6 +234,51 @@ cutlass_test_unit_add_executable( gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_f32_sm80_3x + + sm80_gemm_s8_s8_s32_tensor_op.cu + sm80_gemm_f16_f16_f32_tensor_op_f32.cu + sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm90 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + sm90_gemm_f16_f16_f16_tensor_op.cu + sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu + sm90_gemm_s8_s8_s8_tensor_op_s32.cu + sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu + sm90_gemm_f32_f32_f32_tensor_op_f32.cu +) + +# Alignment tests +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_alignx_sm90 + + BATCH_SOURCES ON + BATCH_SIZE 4 + sm90_gemm_f16_f16_f16_alignx_tensor_op.cu + sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu + sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu + sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu +) + + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80 @@ -226,6 +296,7 @@ cutlass_test_unit_add_executable( gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu + sm80_gemm_f16_f16_f32_tensor_op_f32.cu ) cutlass_test_unit_add_executable( @@ -247,6 +318,9 @@ cutlass_test_unit_add_executable( # SM90 device level tests gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu + + sm80_gemm_f64_f64_f64_tensor_op_f64.cu + gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp new file mode 100644 index 0000000000..f84e92977f --- /dev/null +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -0,0 +1,1343 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace cutlass { +namespace gemm { +namespace device { +using namespace cute; + +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +// +// F16: 128-by-128-by-64 +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,3,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32F16 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + half_t, LayoutA, + half_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Layout>>; // 1x2x1 value group for 16x16x16 MMA and LDSM + + // A + static constexpr int kAlignmentA = 8; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 8; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// TF32: 128-by-128-by-kblock (kBlock = 16, 32) +// + +/// Operand A - Row-major (K-major) (kBlock = 32) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Row-major (K-major) (kBlock = 16) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,2,3>{}, + Layout, + Stride<_16, _1>>{})); + using SmemCopyAtom = Copy_Atom; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride< _1,_32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32TF32 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + tfloat32_t, LayoutA, + tfloat32_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group + Layout>>; // 1x2x1 value group for 16x16x8 and LDSM + + // A + static constexpr int kAlignmentA = 4; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 4; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _64>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Layout>>; // 1x2x1 value group for 16x16x32 and LDSM + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentA = 16; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomA = Copy_Atom; // LDSM works + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentB = 16; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomB = Copy_Atom; + using SmemCopyAtomB = Copy_Atom; // LDSM works + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + int8_t, TagToStrideA_t, + int8_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// SIMT TWO STAGE /////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_Simt_OperandA; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,_128>>; + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); +}; + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,Int<128 + 4>>>; // Padded + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + +}; + +template +struct DefaultGemm_Simt_OperandB; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +} // end namespace detail + +// SIMT Two Stage +template < + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _8>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A + static constexpr int kAlignmentA = 1; + using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 1; + using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + + +// +// DP4A - int8 Proof-of-concept +// + +// SIMT Two Stage TN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; // Tile of atoms (threads) + + // A (M,K) K-major + using ElementA = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomA = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomB = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + + using DispatchPolicy = MainloopSm70TwoStage; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilouge + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage TT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, + Layout>, + Tile,Layout<_2,_16>,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, + Layout>, + Tile,Underscore,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, + Layout>, + Tile,Underscore>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TN (K-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Layout>, // Val layout + Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(SwizzleXor<2,0,2>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(SwizzleXor<2,0,2>{}, + Layout, + Stride<_1, _4>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; + +/* + using EpilogueOutputOp = epilogue::collective::Epilogue< + epilogue::thread::LinearCombination, + Layout, + Stride< _1,_64>>, // SMEM layout + Copy_Atom,double>, // R2S with tiled_mma layout + decltype(make_tiled_copy(Copy_Atom,double>{},// S2R + Layout, + Stride< _1,_16>>{}, // Thread layout + Layout>{})), // Value layout + Copy_Atom,double> // R2G with S2R_dst layout + >; +*/ +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NN (M-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Layout>, // Val layout + Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(SwizzleXor<2,2,0>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(SwizzleXor<2,0,2>{}, + Layout, + Stride<_1, _4>>{}));// N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NT (M-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Layout>, // Val layout + Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(SwizzleXor<2,2,0>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(SwizzleXor<2,2,0>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TT (K-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Layout>, // Val layout + Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(SwizzleXor<2,0,2>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(SwizzleXor<2,2,0>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Hopper fp64 MMA TN +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm90, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + make_ordered_layout(Shape<_128,_16>{}, + Step < _2, _1>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + make_ordered_layout(Shape<_64,_16>{}, + Step < _2, _1>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu index e3fa731cce..cc9430350e 100644 --- a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -50,7 +50,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -193,6 +193,6 @@ TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu index 232c3bb744..e2931b0203 100644 --- a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu @@ -50,7 +50,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -247,6 +247,6 @@ TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu index bfba315d51..eb011e4c53 100644 --- a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -50,7 +50,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -191,7 +191,7 @@ TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x1 ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu index 2c97b6bb74..c0333e7c6a 100644 --- a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu @@ -50,7 +50,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -299,7 +299,7 @@ TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu index 0a5f0497aa..62cb15dd77 100644 --- a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu @@ -46,7 +46,7 @@ #include "testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -220,4 +220,4 @@ TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu index e44b49590e..881d81c8b9 100644 --- a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu @@ -46,7 +46,7 @@ #include "testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -220,4 +220,4 @@ TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) } ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // if (CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // if (CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp new file mode 100644 index 0000000000..24a9e242fa --- /dev/null +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -0,0 +1,717 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail{ + +template +struct TestbedImpl { + // Kernel data types + using ElementA = typename Gemm::GemmKernel::ElementA; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = typename Gemm::GemmKernel::ElementB; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ElementC = typename Gemm::GemmKernel::ElementC; + using StrideC = typename Gemm::GemmKernel::StrideC; + using ElementD = typename Gemm::GemmKernel::ElementD; + using StrideD = typename Gemm::GemmKernel::StrideD; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + // Looks at Cute Stride to check Row / Column Major + template + static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); + } + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); + using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagPackedVector = cutlass::layout::PackedVectorLayout; + + /// Initialization + StrideA stride_a; + StrideB stride_b; + StrideC stride_c; + StrideD stride_d; + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + uint32_t sm_count; + + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + + // + // Methods + // + + TestbedImpl( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): + stride_factor_A(typename LayoutTagA::Stride()), + stride_factor_B(typename LayoutTagB::Stride()), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), + stride_factor_D(stride_factor_D_), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_c = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + auto c_coord = cutlass::make_Coord(M * L, N); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta + ) { + auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\n\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta + ) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > + epilogue_params{ + alpha, beta, + C, D + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return compare_reference( + problem_shape_MNKL, alpha, beta + ); + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = Gemm::GemmKernel::SharedStorageSize; + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20 + ) { + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; + } + else { + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + } + + // DefaultEpilogue + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + tensor_A.device_data(), + stride_a, + tensor_B.device_data(), + stride_b, + {tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d, {alpha, beta}}, + hw_info + }; + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify( + problem_size, alpha, beta + ); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) + << "\n"; + } + + return passed; + } + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed { + + using TestBedImplementation = typename detail::TestbedImpl; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using LayoutTagA = typename TestBedImplementation::LayoutTagA; + using LayoutTagB = typename TestBedImplementation::LayoutTagB; + using LayoutTagC = typename TestBedImplementation::LayoutTagC; + using LayoutTagD = typename TestBedImplementation::LayoutTagD; + + // Detail Implementation + TestBedImplementation impl_; + + // + // Methods + // + Testbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImplementation::kDefaultSeed) + : impl_(init_A_, init_B_, init_C_, seed_) {} + + Testbed( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImplementation::kDefaultSeed) + : impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_) {} + + /// Executes one test + bool run( + typename TestBedImplementation::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20 + ) { + return impl_.run( + problem_size, alpha, beta, profiling, iterations + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAll() { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (std::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0) + ); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + true, + iterations + ); + + if (!passed) { + return false; + } + } + } + } + + + // if we do support batched GEMM, just run it once + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{problem_size_m[0], problem_size_n[0], problem_size_k[0], /* l */ 4}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + true, + iterations + ); + + if (!passed) { + return false; + } + } + + return passed; +} + + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index b8880c56d2..9a115493c4 100644 --- a/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -48,7 +48,7 @@ #include "testbed_symm_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,4 +132,4 @@ TEST(SM90_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu index 135ab9430f..fbc4efdb93 100644 --- a/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu @@ -46,7 +46,7 @@ #include "testbed_rank2k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -146,4 +146,4 @@ TEST(SM90_Device_Her2k_cf64c_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu index 28e4c4d64b..114a20cf9c 100644 --- a/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu @@ -46,7 +46,7 @@ #include "testbed_rank_k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// // HERK operator on CUBLAS_OP_C (row-major + conj) input layouts @@ -90,4 +90,4 @@ TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { } ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h index 252fe7b37d..681e051e31 100644 --- a/test/unit/gemm/device/multistage_testbed.h +++ b/test/unit/gemm/device/multistage_testbed.h @@ -58,6 +58,11 @@ namespace device { template struct MultistageTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/multistage_testbed_interleaved.h b/test/unit/gemm/device/multistage_testbed_interleaved.h index c5edc0ed26..5f332069d1 100644 --- a/test/unit/gemm/device/multistage_testbed_interleaved.h +++ b/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -59,6 +59,9 @@ namespace device { template struct MultistageInterleavedTestbed { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; @@ -110,12 +113,49 @@ struct MultistageInterleavedTestbed { return true; } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + /// Executes one test bool run( cutlass::gemm::GemmCoord problem_size, ElementCompute alpha = ElementCompute(1), ElementCompute beta = ElementCompute(0)) { + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + // // Allocate the GEMM workspace // diff --git a/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu b/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu new file mode 100644 index 0000000000..f7a18bc667 --- /dev/null +++ b/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu b/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu new file mode 100644 index 0000000000..421072fee9 --- /dev/null +++ b/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM50_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu b/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu new file mode 100644 index 0000000000..ba6456b57a --- /dev/null +++ b/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +//#if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM61_Device_Gemm_s8n_s8n_s32n_simt_s32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, cutlass::layout::ColumnMajor, + int32_t>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM61_Device_Gemm_s8n_s8t_s32n_simt_s32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + int32_t, cutlass::layout::ColumnMajor, + int32_t>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM61_Device_Gemm_s8t_s8n_s32n_simt_s32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, cutlass::layout::ColumnMajor, + int32_t>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM61_Device_Gemm_s8t_s8t_s32n_simt_s32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm50, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + int32_t, cutlass::layout::ColumnMajor, + int32_t>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#endif // #if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) diff --git a/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu new file mode 100644 index 0000000000..40f7cdb29a --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +#if 1 +TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} +#endif +///////////////////////////////////////////////////////////////////////////////////////////////// +#if 1 +TEST(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} +#endif +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu b/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu new file mode 100644 index 0000000000..a7c6b522c5 --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu b/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu new file mode 100644 index 0000000000..274b30cf82 --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu b/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu new file mode 100644 index 0000000000..e53a8e8122 --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64, 128x128x64_64x64x64) { + + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// #endif diff --git a/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu b/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu new file mode 100644 index 0000000000..d53cf54267 --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(DISABLED_SM80_Device_Gemm_s8n_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(DISABLED_SM80_Device_Gemm_s8n_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, cutlass::layout::ColumnMajor, + int32_t>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(DISABLED_SM80_Device_Gemm_s8t_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu new file mode 100644 index 0000000000..14654c781f --- /dev/null +++ b/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "default_gemm_configuration.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + + +//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::tfloat32_t, cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32n_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::tfloat32_t, cutlass::layout::ColumnMajor, + cutlass::tfloat32_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_tf32t_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::tfloat32_t, cutlass::layout::RowMajor, + cutlass::tfloat32_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM80_Device_Gemm_tf32t_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::tfloat32_t, cutlass::layout::RowMajor, + cutlass::tfloat32_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu new file mode 100644 index 0000000000..9fbbd86286 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu @@ -0,0 +1,188 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 4, + cutlass::bfloat16_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 2, + cutlass::bfloat16_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu new file mode 100644 index 0000000000..d3983e4d10 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu new file mode 100644 index 0000000000..0ee526bec0 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu @@ -0,0 +1,449 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu new file mode 100644 index 0000000000..4fea99ab22 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu @@ -0,0 +1,1077 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 128x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 128x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + cutlass::half_t, + Shape<_128,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 128x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 128x64x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_64,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, + Copy_Atom>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu new file mode 100644 index 0000000000..16466329f5 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu @@ -0,0 +1,582 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x2x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTma + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu new file mode 100644 index 0000000000..378315d6cb --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu @@ -0,0 +1,582 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x2x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu new file mode 100644 index 0000000000..c7d814b858 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu @@ -0,0 +1,1018 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x4x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::half_t; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,Stride<_1,_64>>; + using TileShapeS2R = Shape<_64,_16>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::half_t; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_64,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,_64>,Stride,_64>>; + using TileShapeS2R = Shape<_128,_8>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::half_t; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; + using TileShapeS2R = Shape<_8,_128>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::half_t; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_64,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,Stride<_64,_1>>; + using TileShapeS2R = Shape<_16,_64>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,Stride<_1,_64>>; + using TileShapeS2R = Shape<_64,_16>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_64,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,_64>,Stride,_64>>; + using TileShapeS2R = Shape<_128,_8>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_64,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; + using TileShapeS2R = Shape<_8,_128>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_64,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + + using PreSwizzleLayout = Layout,Stride<_64,_1>>; + using TileShapeS2R = Shape<_16,_64>; + + using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, + Copy_Atom, + TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, + Copy_Atom>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedPersistent + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu new file mode 100644 index 0000000000..b4edaf61ca --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -0,0 +1,86 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/default_transposed_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32, 64x128x32_1x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + float, LayoutA, 4, + float, LayoutB, 4, + float, + Shape<_64,_128,_128>, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu new file mode 100644 index 0000000000..5d30e9614a --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 8, + int8_t, LayoutB, 8, + int32_t, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 4, + int8_t, LayoutB, 4, + int32_t, + Shape<_128,_64,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu new file mode 100644 index 0000000000..f0762a9dce --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128_1x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_64,_128,_128>, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_1x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_2,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu new file mode 100644 index 0000000000..e95772f3b6 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + tfloat32_t, LayoutA, 4, + tfloat32_t, LayoutB, 4, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 2, + cutlass::tfloat32_t, LayoutB, 2, + float, + Shape<_64,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu new file mode 100644 index 0000000000..ce570a2f12 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -0,0 +1,185 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 4, + cutlass::tfloat32_t, LayoutB, 4, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 4, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 4, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index 09ec485dfb..a13f744a64 100644 --- a/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -48,7 +48,7 @@ #include "testbed_symm_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -130,4 +130,4 @@ TEST(SM90_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu index bd3f99e49a..1feb2d67bb 100644 --- a/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu @@ -47,7 +47,7 @@ #include "testbed_symm_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,4 +132,4 @@ TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { } ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu index f33a01aafb..76d19f650d 100644 --- a/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu @@ -47,7 +47,7 @@ #include "testbed_rank2k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -147,4 +147,4 @@ TEST(SM90_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu index efd56c0f02..f7aa84db3b 100644 --- a/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu @@ -47,7 +47,7 @@ #include "testbed_rank2k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -131,4 +131,4 @@ TEST(SM90_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu index 5e30d4298d..98da67d310 100644 --- a/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu @@ -47,7 +47,7 @@ #include "testbed_rank_k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -133,4 +133,4 @@ TEST(SM90_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu index a6867e88d4..8fe762775d 100644 --- a/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu @@ -47,7 +47,7 @@ #include "testbed_rank_k_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,4 +123,4 @@ TEST(SM90_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index 43cf25a6e6..dc21f41fe9 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -65,6 +65,9 @@ namespace device { template struct Testbed { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_complex.h b/test/unit/gemm/device/testbed_complex.h index 77d5be8518..244bc0682e 100644 --- a/test/unit/gemm/device/testbed_complex.h +++ b/test/unit/gemm/device/testbed_complex.h @@ -63,6 +63,9 @@ template struct TestbedComplex : public Testbed { using Base = Testbed; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; @@ -131,7 +134,7 @@ struct TestbedComplex : public Testbed { if (properties.sharedMemPerBlockOptin < smem_size) { return false; } - + return true; } diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index a0939e0a93..10d5d3f0f2 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -100,6 +100,8 @@ template < > struct TestbedGemmWithBroadcast { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; diff --git a/test/unit/gemm/device/testbed_gemm_with_reduction.h b/test/unit/gemm/device/testbed_gemm_with_reduction.h index 07c1c11cfe..6f220b1eb1 100644 --- a/test/unit/gemm/device/testbed_gemm_with_reduction.h +++ b/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -61,6 +61,7 @@ namespace device { template struct GemmWithReductionReference { + using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; using ElementC = typename Gemm::ElementC; @@ -93,6 +94,9 @@ template < > struct TestbedGemmWithReduction { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; diff --git a/test/unit/gemm/device/testbed_interleaved.h b/test/unit/gemm/device/testbed_interleaved.h index 57e9f0122f..b54a4b6b8e 100644 --- a/test/unit/gemm/device/testbed_interleaved.h +++ b/test/unit/gemm/device/testbed_interleaved.h @@ -57,6 +57,9 @@ namespace device { template struct InterleavedTestbed { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_rank2k_universal.h b/test/unit/gemm/device/testbed_rank2k_universal.h index ecc1e4b1a9..29f398964a 100644 --- a/test/unit/gemm/device/testbed_rank2k_universal.h +++ b/test/unit/gemm/device/testbed_rank2k_universal.h @@ -64,6 +64,9 @@ namespace device { template struct TestbedRank2KUniversal { + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; using ElementAccumulator = typename Rank2K::ElementAccumulator; using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; @@ -301,7 +304,6 @@ struct TestbedRank2KUniversal { if (properties.sharedMemPerBlockOptin < smem_size) { return false; } - return true; } diff --git a/test/unit/gemm/device/testbed_rank_k_universal.h b/test/unit/gemm/device/testbed_rank_k_universal.h index 6e0fa5db0d..7c403ad8b5 100644 --- a/test/unit/gemm/device/testbed_rank_k_universal.h +++ b/test/unit/gemm/device/testbed_rank_k_universal.h @@ -63,6 +63,8 @@ namespace device { template struct TestbedRank2KUniversal { + using ElementA = typename RankK::ElementA; + using ElementC = typename RankK::ElementC; using ElementAccumulator = typename RankK::ElementAccumulator; using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index fd2ab20aba..56f3e5ee8e 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -64,6 +64,9 @@ namespace device { template struct SparseTestbed { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_symm_universal.h b/test/unit/gemm/device/testbed_symm_universal.h index 79873e04d4..1050a2edcc 100644 --- a/test/unit/gemm/device/testbed_symm_universal.h +++ b/test/unit/gemm/device/testbed_symm_universal.h @@ -64,6 +64,9 @@ namespace device { template struct TestbedSymmUniversal { + using ElementA = typename Symm::ElementA; + using ElementB = typename Symm::ElementB; + using ElementC = typename Symm::ElementC; using ElementAccumulator = typename Symm::ElementAccumulator; using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_trmm_universal.h b/test/unit/gemm/device/testbed_trmm_universal.h index 6a8bc2602f..db40eff767 100644 --- a/test/unit/gemm/device/testbed_trmm_universal.h +++ b/test/unit/gemm/device/testbed_trmm_universal.h @@ -66,6 +66,9 @@ namespace device { template struct TestbedTrmmUniversal { + using ElementA = typename Trmm::ElementA; + using ElementB = typename Trmm::ElementB; + using ElementC = typename Trmm::ElementC; using ElementAccumulator = typename Trmm::ElementAccumulator; using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index eb12d856e9..615e9c5c6b 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -61,6 +61,9 @@ namespace device { template struct TestbedUniversal { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; diff --git a/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index c6b34e9553..437bed55b4 100644 --- a/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -48,7 +48,7 @@ #include "testbed_trmm_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,4 +134,4 @@ TEST(SM90_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16 ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu index 55198183e2..5339bc556b 100644 --- a/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu @@ -48,7 +48,7 @@ #include "testbed_trmm_universal.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -124,4 +124,4 @@ TEST(SM90_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h index 03ff561c47..6e14745eb5 100644 --- a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -241,8 +241,6 @@ struct SparseTestbed { // Determine SMEM requirements and waive if not satisfied // - int smem_size = int(sizeof(typename Mma::SharedStorage)); - cudaDeviceProp properties; int device_idx; cudaError_t result = cudaGetDevice(&device_idx); @@ -257,10 +255,6 @@ struct SparseTestbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - return true; } @@ -415,7 +409,12 @@ struct SparseTestbed { bool passed = cutlass::reference::host::TensorEquals( matrix_C_computed.host_view(), matrix_C_reference.host_view()); - EXPECT_TRUE(passed) + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + + std::cout + << __FILE__ << ":" << __LINE__ << " " << "A:\n" << matrix_A.host_view() << "\n" << "B:\n" << matrix_B.host_view() << "\n" << "E:\n" << matrix_E.host_view() << "\n" @@ -423,6 +422,7 @@ struct SparseTestbed { << matrix_C_reference.host_view() << "\n" << "Computed:\n" << matrix_C_computed.host_view() << "\n"; + } EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); diff --git a/test/unit/gemm/threadblock/mma_multistage_testbed.h b/test/unit/gemm/threadblock/mma_multistage_testbed.h index 946dadfc77..1e859b6184 100644 --- a/test/unit/gemm/threadblock/mma_multistage_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -193,11 +193,40 @@ struct Testbed { matrix_C_reference.reset(cutlass::make_Coord(m, n), false); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + /// Runs the test bool run( dim3 grid, dim3 block, cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + // // initialize device memory // @@ -318,13 +347,18 @@ struct Testbed { bool passed = cutlass::reference::host::TensorEquals( matrix_C_computed.host_view(), matrix_C_reference.host_view()); - EXPECT_TRUE(passed) + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cout + << __FILE__ << ":" << __LINE__ << " " << "A:\n" << matrix_A.host_view() << "\n" << "B:\n" << matrix_B.host_view() << "\n" << "Reference:\n" << matrix_C_reference.host_view() << "\n" << "Computed:\n" << matrix_C_computed.host_view() << "\n"; + } EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); diff --git a/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/test/unit/gemm/threadblock/mma_pipelined_testbed.h index 6eec564609..6f36b53e71 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_testbed.h +++ b/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -217,11 +217,25 @@ struct Testbed { matrix_C_reference.reset(cutlass::make_Coord(m, n), false); } + bool sufficient() { + return true; + } + /// Runs the test bool run( dim3 grid, dim3 block, cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // // initialize device memory // @@ -300,7 +314,7 @@ struct Testbed { cudaError_t result = cudaDeviceSynchronize(); EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); + << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); matrix_C_computed.sync_host(); @@ -316,7 +330,7 @@ struct Testbed { bool passed = cutlass::reference::host::TensorEquals( matrix_C_computed.host_view(), matrix_C_reference.host_view()); - EXPECT_TRUE(passed); + EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); if (!passed) { std::ofstream output("mma_pipelined_testbed_errors.txt"); diff --git a/test/unit/gemm/warp/gemm_complex_sm90.cu b/test/unit/gemm/warp/gemm_complex_sm90.cu index c30e414734..38bdfa65d8 100644 --- a/test/unit/gemm/warp/gemm_complex_sm90.cu +++ b/test/unit/gemm/warp/gemm_complex_sm90.cu @@ -50,7 +50,7 @@ #include "testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) { @@ -331,4 +331,4 @@ TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x64x4_16x8x4_tn) { test::gemm::warp::TestbedComplex().run(); } -#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/warp/gemm_sm90.cu b/test/unit/gemm/warp/gemm_sm90.cu index ebb7d91a7c..f417a41fcc 100644 --- a/test/unit/gemm/warp/gemm_sm90.cu +++ b/test/unit/gemm/warp/gemm_sm90.cu @@ -50,7 +50,7 @@ #include "testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) { using Shape = cutlass::gemm::GemmShape<16, 16, 4>; @@ -203,4 +203,4 @@ TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) { } //////////////////////////////////////////////////////////////////////////////// -#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index 6d7e143f57..3487aa0ffd 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -191,10 +191,47 @@ struct Testbed { tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test bool run( cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + // // initialize device memory // @@ -401,10 +438,46 @@ struct TestbedComplex { tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + /// Runs the test bool run( cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + // // initialize device memory // @@ -676,10 +749,46 @@ struct TransformTestbed { tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + /// Runs the test bool run( cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + // // initialize device memory // @@ -878,10 +987,46 @@ struct TransformedTestbedComplex { tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + /// Runs the test bool run( cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + // // initialize device memory // @@ -1199,12 +1344,47 @@ struct SparseTestbed { Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); } + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + /// Runs the test bool run( cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + if (!sufficient()) { + return true; + } + // // initialize device memory // diff --git a/test/unit/pipeline/CMakeLists.txt b/test/unit/pipeline/CMakeLists.txt new file mode 100644 index 0000000000..fb38dc7285 --- /dev/null +++ b/test/unit/pipeline/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_pipeline + pipeline_tma_async.cu + pipeline_tma_async_warp_specialized.cu + pipeline_tma_async_warp_specialized_persistent.cu + pipeline_async.cu + sequence_barrier.cu +) diff --git a/test/unit/pipeline/pipeline_async.cu b/test/unit/pipeline/pipeline_async.cu new file mode 100644 index 0000000000..d2adad6a30 --- /dev/null +++ b/test/unit/pipeline/pipeline_async.cu @@ -0,0 +1,468 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the PipelineAsync class +*/ + +#define KERNEL_DBG_TRACE false + +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" + +using namespace cute; + +//////////////////// KERNEL ///////////////////////// + +template +struct SharedStorage +{ + typename cutlass::PipelineAsync::SharedStorage storage; +}; + +// Goal of this kernel is to complete deadlock-free +// Simple 1 producer warp, one consumer warp scenario +template +__global__ static +void pipeline_async_basic_device(uint32_t const num_iterations) +{ + + extern __shared__ char shared_memory[]; + using MainloopPipeline = typename cutlass::PipelineAsync; + using PipelineState = typename cutlass::PipelineState; + + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + + auto cta_layout = Layout{}; // (m,n) -> cta_id + + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_predicate = cute::elect_one_sync(); + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_shape = ClusterShape{}; + + // This example showcases 2 producer 1 consumer example + typename MainloopPipeline::Params params; + params.producer_arv_count = 2; + params.consumer_arv_count = 1; + MainloopPipeline pipeline(shared_storage.storage, params); + + // Ensure All CTAs in Cluster have completed init before issuing commits + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + __syncthreads(); + + if (lane_predicate) { + // Producer Warps + if (warp_idx==0 || warp_idx==1) { + + int prologue_iterations = min(NumStages, num_iterations); + for ( int i = 0; i < prologue_iterations; ++i) { + // Can also specify stage to commit directly + pipeline.producer_commit(i); + } + + int mainloop_iterations = num_iterations - prologue_iterations; + + // Only the mainloop needs a PipelineState because this is where we start "waiting" (acquiring) + PipelineState smem_pipe_write; + + for ( ; mainloop_iterations > 0; --mainloop_iterations) { + pipeline.producer_acquire(smem_pipe_write); + pipeline.producer_commit(smem_pipe_write); + ++smem_pipe_write; + } + } + else { + PipelineState smem_pipe_read; + for (int iter=0 ; iter < num_iterations; ++iter) { + pipeline.consumer_wait(smem_pipe_read); + pipeline.consumer_release(smem_pipe_read.index()); + ++smem_pipe_read; + } + } + } + + // To make sure remote SMEM doesn't get destroyed + cute::cluster_arrive(); + cute::cluster_wait(); +} +///////////////////////////////////////////////////// + +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t kBlockSize = 96; + using ClusterShape = ClusterShape_; + + // + // Methods + // + + // Ctor + PipelineTest() = default; + + + // Run CuTe GEMM kernel + cudaError_t run(uint32_t const kNumIters, + cudaStream_t stream = nullptr) { + + // Pipeline (multistage pipeline) + auto num_stages = Int{}; + + auto cluster_shape = Shape, Int, _1>{}; + + // + // Configure and launch + // + int iterations = 2; + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + + // Define the tiled MMA layout (static, 4warps) + using MainloopPipeline = typename cutlass::PipelineAsync; + + int smem_size = int(sizeof(SharedStorage)); + + result = cudaFuncSetAttribute( + pipeline_async_basic_device, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + // Launch a single Cluster, with 128 thread per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimBlock(kBlockSize,1,1); + + const void* kernel = (const void*)pipeline_async_basic_device; + int iters = kNumIters; + void* kernel_params[] = {reinterpret_cast(&iters)}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + } // profiling loop ends + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + return cudaSuccess; + } + +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage3) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 3; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage4) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage6) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 6; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage8) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 8; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage9) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 9; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage11) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 11; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/pipeline_tma_async.cu b/test/unit/pipeline/pipeline_tma_async.cu new file mode 100644 index 0000000000..90e0ca3a02 --- /dev/null +++ b/test/unit/pipeline/pipeline_tma_async.cu @@ -0,0 +1,469 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the PipelineTmaAsync class +*/ + + +#define KERNEL_DBG_TRACE false + +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" + +using namespace cute; + +//////////////////// KERNEL ///////////////////////// + +template +struct SharedStorage +{ + typename cutlass::PipelineTmaAsync::SharedStorage storage; +}; + +// Goal of this kernel is to complete deadlock-free +template +__global__ static +void pipeline_device(uint32_t const NumIterations) +{ + + extern __shared__ char shared_memory[]; + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto cta_layout = Layout{}; // (m,n) -> cta_id + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_group_thread_idx = threadIdx.x % 128; + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + auto cluster_shape = ClusterShape{}; + + // #Producers = #RowsInCluster + #ColsInCluster - 1 + uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; + uint32_t const TmaTransactionBytes = sizeof(uint32_t) * NumProducers; + uint32_t const per_cta_bytes = sizeof(uint32_t); + + // mbarrier.init + typename MainloopPipeline::Params params; + params.transaction_bytes = TmaTransactionBytes; + params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + params.is_leader = warp_group_thread_idx == 0; + params.num_consumers = 128; + + MainloopPipeline pipeline(shared_storage.storage, params); + + __syncthreads(); + + // Ensure All CTAs in Cluster have completed init before issuing commits + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + + // Total number of gemm_k_iterations + auto mma_k_iterations = NumIterations; + auto tma_k_iterations = NumIterations; + + PipelineState smem_pipe_read; + // For the DMA (prologue) - we start with an opposite phase - since we skip all waits + // i.e., we know that the buffer is indeed empty + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_release; + int K_TILE_MMAS = 1; + + int lane_predicate = cute::elect_one_sync(); + int k_pipe_tma_prologue = min(NumStages, tma_k_iterations); + + // DMA Prologue (Loads) + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < k_pipe_tma_prologue; ++i) { + pipeline.producer_acquire(smem_pipe_write); + // cp.async.bulk.tensor would typically happen here + pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + ++smem_pipe_write; + } + tma_k_iterations -= k_pipe_tma_prologue; + + // MMA Prologue (Compute) - modeling inflight MMAs + for (int iter = 0; iter < K_TILE_MMAS; ++iter) + { + pipeline.consumer_wait(smem_pipe_read); + warpgroup_arrive(); + // GMMA would typically happen here + + ++smem_pipe_read; + } + + mma_k_iterations -= K_TILE_MMAS; + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < mma_k_iterations; ++iter) + { + pipeline.consumer_wait(smem_pipe_read); + + warpgroup_arrive(); + // GMMA would typically happen here + + pipeline.consumer_release(smem_pipe_release); + + if (lane_predicate && (warp_idx == 0) && (tma_k_iterations > 0)) { + pipeline.producer_acquire(smem_pipe_write); + // cp.async.bulk.tensor would typically happen here + pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + ++smem_pipe_write; + --tma_k_iterations; + } + + // next read stage + ++smem_pipe_read; + ++smem_pipe_release; + } + + // To make sure remote SMEM doesn't get destoryed + cute::cluster_arrive(); + cute::cluster_wait(); +} +///////////////////////////////////////////////////// + +/// Device NT GMMA + TMA specialized +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t kBlockSize = 128; + using ClusterShape = ClusterShape_; + + // + // Methods + // + + // Ctor + PipelineTest(){}; + + + // Run CuTe GEMM kernel + cudaError_t run(uint32_t const kNumIters, + cudaStream_t stream = 0) { + + float elapsed_ms = 0.0f; + // Pipeline (multistage pipeline) + auto num_stages = Int{}; + + auto cluster_shape = Shape, Int, _1>{}; + + // + // Configure and launch + // + int iterations = 1; + cudaEvent_t events[2]; + cudaError_t result; + + for (cudaEvent_t & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "Error: Failed to create event."; + return result; + } + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record start event."; + return result; + } + + for (int iter = 0; iter < iterations; ++iter) { + + // Define the tiled MMA layout (static, 4warps) + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + + int smem_size = int(sizeof(SharedStorage)); + + result = cudaFuncSetAttribute( + pipeline_device, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + // Launch a single Cluster, with 128 thread per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimBlock(kBlockSize,1,1); + + const void* kernel = (const void*)pipeline_device; + int iters = kNumIters; + void* kernel_params[] = {reinterpret_cast(&iters)}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + } // profiling loop ends + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record stop event."; + return result; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + if (result != cudaSuccess) { + std::cerr << "Failed to create event."; + return result; + } + + for (cudaEvent_t & event : events) { + (void)cudaEventDestroy(event); + } + + return cudaSuccess; + } +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu new file mode 100644 index 0000000000..f0d6a79c55 --- /dev/null +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu @@ -0,0 +1,525 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the PipelineTmaAsync class as it would be used in a Warp specialized loop +*/ + +#define KERNEL_DBG_TRACE false + +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" + + +using namespace cute; +using namespace cutlass; + +//////////////////// KERNEL ///////////////////////// + +template +struct SharedStorage +{ + typename cutlass::PipelineTmaAsync::SharedStorage storage ; +}; + +struct KernelParams +{ + uint32_t num_iterations; + int* data_ptr; +}; + +// Goal of this kernel is to complete deadlock-free +template +__launch_bounds__(384, 1) +__global__ static +void pipeline_device(KernelParams const kernel_params) +{ + extern __shared__ char shared_memory[]; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto cta_layout = Layout{}; // (m,n) -> cta_id + int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + int warp_group_thread_idx = threadIdx.x % 128; + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + auto cluster_shape = ClusterShape{}; + + // #Producers = #RowsInCluster + #ColsInCluster - 1 + uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; + uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); + uint32_t const per_cta_bytes = sizeof(uint32_t); + + // mbarrier.init + typename MainloopPipeline::Params params; + params.transaction_bytes = TmaTransactionBytes; + if (warp_group_idx == 0) { + params.role = MainloopPipeline::ThreadCategory::Producer; + } + else { + params.role = MainloopPipeline::ThreadCategory::Consumer; + } + params.is_leader = warp_group_thread_idx == 0; + params.num_consumers = 128; + + MainloopPipeline pipeline(shared_storage.storage, params); + + __syncthreads(); + + // Ensure All CTAs in Cluster have completed init before issuing commits + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + + + // Producer WarpGroup + if (warp_group_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<232>(); + + int lane_predicate = cute::elect_one_sync(); + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + + int tma_k_prologue = min(Stages, kernel_params.num_iterations); + + // Simulating Prologue TMA Loads + // For the DMA (prologue) - we start with an opposite phase - since we skip all waits + // i.e., we know that the buffer is indeed empty + PipelineState smem_pipe_write = make_producer_start_state(); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < tma_k_prologue; ++i) { + pipeline.producer_acquire(smem_pipe_write); + // Simulating cp.async.bulk.tensor behavior + pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + ++smem_pipe_write; + } + int tma_k_iter = kernel_params.num_iterations - tma_k_prologue; + + // Simulating Mainloop TMA Loads + CUTE_NO_UNROLL + for ( ; tma_k_iter > 0; --tma_k_iter) { + + pipeline.producer_acquire(smem_pipe_write); + + // Simulating cp.async.bulk.tensor behavior + pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + + // Advance write stage + ++smem_pipe_write; + } + + // Tail Loop + // Handles the case where we never enter the mainloop + PipelineState tail = tma_k_prologue == Stages ? smem_pipe_write : PipelineState{}; + for ( int i = 0; i < tma_k_prologue; ++i) { + pipeline.producer_acquire(tail); + ++tail; + } + } + // Consumer WarpGroup + } else if(warp_group_idx == 1) { + cutlass::arch::warpgroup_reg_alloc<232>(); + + PipelineState smem_pipe_read; + PipelineState smem_pipe_release; + + // simulates accumulators + extra reg. pressure + int arr[168]; + + // Init Shared Memory read stages & PhaseBit + static constexpr uint32_t K_PIPE_MMAS = 1; + static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); + + // Total number of gemm iterations + auto gemm_k_iterations = kernel_params.num_iterations; + + // Simulating Prologue MMAs + int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < mma_k_prologue; ++iter) { + pipeline.consumer_wait(smem_pipe_read); + + warpgroup_arrive(); + // GMMA would typically happen here + + ++smem_pipe_read; + } + gemm_k_iterations -= mma_k_prologue; + + // Simulating Mainloop MMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { + + /// Wait on the smem_pipe_read stage / phase + pipeline.consumer_wait(smem_pipe_read); + + warpgroup_arrive(); + // GMMA would typically happen here + + // Dummy op - which will never happen + // But simulates high register usage. + CUTE_UNROLL + for(int i = 0; i < 168; ++i){ + if (threadIdx.x > 256){ + arr[i] += kernel_params.data_ptr[i]; + } + } + + pipeline.consumer_release(smem_pipe_release); + + // Advance stages + ++smem_pipe_read; + ++smem_pipe_release; + } + + // Dummy op - which will never happen + CUTE_UNROLL + for(int i = 0; i < 168; ++i){ + if (threadIdx.x > 256){ + kernel_params.data_ptr[i] = arr[i]; + } + } + + // Tail Loop + for (int i = 0; i < K_PIPE_MMAS; ++i){ + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + } + + // Warp-Group #2 + } else { + cutlass::arch::warpgroup_reg_dealloc<40>(); + } +} +///////////////////////////////////////////////////// + +/// Device NT GMMA + TMA specialized +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t kBlockSize = 128 * 3; + using ClusterShape = ClusterShape_; + + // + // Methods + // + + // Ctor + PipelineTest(){}; + + // Run CuTe GEMM kernel + cudaError_t run(uint32_t const kNumIters, + cudaStream_t stream = 0) { + + float elapsed_ms = 0.0f; + // Pipeline (multistage pipeline) + auto num_stages = Int{}; + auto cluster_shape = Shape, Int, _1>{}; + + // + // Configure and launch + // + int iterations = 1; + cudaEvent_t events[2]; + cudaError_t result; + + for (cudaEvent_t & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "Error: Failed to create event."; + return result; + } + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record start event."; + return result; + } + + for (int iter = 0; iter < iterations; ++iter) { + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + + int smem_size = int(sizeof(SharedStorage)); + + result = cudaFuncSetAttribute( + pipeline_device, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + // Launch a single Cluster, with kBlockSize threads per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimBlock(kBlockSize,1,1); + + const void* kernel = (const void*)pipeline_device; + KernelParams params{kNumIters, nullptr}; + void* kernel_params[] = {reinterpret_cast(¶ms)}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + } + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record stop event."; + return result; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + if (result != cudaSuccess) { + std::cerr << "Failed to create event."; + return result; + } + + for (cudaEvent_t & event : events) { + (void)cudaEventDestroy(event); + } + + return cudaSuccess; + } +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu new file mode 100644 index 0000000000..4b6a3b1d26 --- /dev/null +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the PipelineTmaAsync class used in a WarpSpecialized Persistent loop +*/ + +#define KERNEL_DBG_TRACE false + +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" + + +using namespace cute; +using namespace cutlass; + +//////////////////// KERNEL ///////////////////////// + +template +struct SharedStorage +{ + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_storage; + typename PingPongBarrier::SharedStorage pingpong_storage; +}; + +template +struct CollectiveSimulation { + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + + CUTLASS_DEVICE + static void + dma_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, + uint32_t const num_iterations) { + uint32_t const per_cta_bytes = sizeof(uint32_t); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + int lane_predicate = cute::elect_one_sync(); + if (warp_idx_in_warpgroup==0 && lane_predicate) { + + int tma_k_prologue = min(Stages, num_iterations); + + // Simulating Prologue TMA Loads + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < tma_k_prologue; ++i) { + pipeline.producer_acquire(tile_start_state_pipe); + // Simulating cp.async.bulk.tensor behavior + pipeline.producer_commit(tile_start_state_pipe.index(), per_cta_bytes); + ++tile_start_state_pipe; + } + int tma_k_iter = num_iterations - tma_k_prologue; + + PipelineState wr_pipe = tile_start_state_pipe; + // Simulating Mainloop TMA Loads + CUTE_NO_UNROLL + for ( ; tma_k_iter > 0; --tma_k_iter){ + + pipeline.producer_acquire(wr_pipe); + + // Simulating cp.async.bulk.tensor behavior + pipeline.producer_commit(wr_pipe.index(), per_cta_bytes); + + // Advance write stage + ++wr_pipe; + } + } + } + + CUTLASS_DEVICE + static void + math_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, + uint32_t const num_iterations, int* data_ptr) { + PipelineState rd_pipe = tile_start_state_pipe; + PipelineState release_pipe = rd_pipe; + + // simulates accumulators + extra reg. pressure + int arr[168]; + + // Init Shared Memory read stages & PhaseBit + static constexpr uint32_t K_PIPE_MMAS = 1; + static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); + + // Total number of gemm iterations + auto gemm_k_iterations = num_iterations; + + // Simulating Prologue MMAs + int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < mma_k_prologue; ++iter) { + pipeline.consumer_wait(rd_pipe); + + warpgroup_arrive(); + // GMMA would typically happen here + + ++rd_pipe; + } + gemm_k_iterations -= mma_k_prologue; + + // Simulating Mainloop MMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { + + /// Wait on the rd_pipe stage / phase + pipeline.consumer_wait(rd_pipe); + + warpgroup_arrive(); + // GMMA would typically happen here + + // Dummy op - which will never happen + // But simulates high register usage. + CUTE_UNROLL + for(int i = 0; i < 168; ++i){ + if (threadIdx.x > 384){ + arr[i] += data_ptr[i]; + } + } + + pipeline.consumer_release(release_pipe); + + // Advance stages + ++rd_pipe; + ++release_pipe; + } + + // Dummy op - which will never happen + CUTE_UNROLL + for(int i = 0; i < 168; ++i){ + if (threadIdx.x > 384){ + data_ptr[i] = arr[i]; + } + } + + // Tail Loop + for (int i = 0; i < K_PIPE_MMAS; ++i){ + pipeline.consumer_release(release_pipe); + ++release_pipe; + } + + } +}; + +struct KernelParams +{ + uint32_t num_iterations; + int tiles_per_cluster; + int* data_ptr; +}; + +// Goal of this kernel is to complete deadlock-free +template +__launch_bounds__(384, 1) +__global__ static +void pipeline_device(KernelParams params) +{ + extern __shared__ char shared_memory[]; + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + + /* One for Mainloop and one for Epilogue */ + constexpr int StagesPerMathWarpGroup = 2; + constexpr int MathWarpGroupCountPersistent = 2; + using PingPongBarrier = typename cutlass::OrderedSequenceBarrier; + + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto cta_layout = Layout{}; // (m,n) -> cta_id + int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); + int warp_group_thread_idx = threadIdx.x % NumThreadsPerWarpGroup; + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + auto cluster_shape = ClusterShape{}; + + // #Producers = #RowsInCluster + #ColsInCluster - 1 + uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; + uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); + + // mbarrier.init + typename MainloopPipeline::Params pipeline_params; + pipeline_params.transaction_bytes = TmaTransactionBytes; + if (warp_group_idx == 0) { + pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + else { + pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumThreadsPerWarpGroup; + + MainloopPipeline pipeline(shared_storage.pipeline_storage, pipeline_params); + PipelineState tile_start_state_pipe; + + int tiles_per_cluster = params.tiles_per_cluster; + + /* Offset pipeline start state for Math WG 2 */ + if (warp_group_idx == 2) { + // Update pipeline state for next persistent tile + tile_start_state_pipe.advance(params.num_iterations); + tiles_per_cluster--; + } + + typename PingPongBarrier::Params pingpong_params; + pingpong_params.group_id = warp_group_idx - 1; // Since DMA Warp Group Idx 0 will not participate + pingpong_params.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + PingPongBarrier math_wg_barrier(shared_storage.pingpong_storage, pingpong_params); + + __syncthreads(); + + // Ensure All CTAs in Cluster have completed init before issuing commits + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + + // Producer/DMA WarpGroup + if (warp_group_idx == 0) { + cutlass::arch::warpgroup_reg_dealloc<40>(); + // For the DMA (prologue) - we start with an opposite phase - since we skip all waits + // i.e., we know that the buffer is indeed empty + PipelineState tile_prologue_state_pipe = make_producer_start_state(); + while (tiles_per_cluster > 0) { + CollectiveSimulation::dma_wg_simulation(pipeline, tile_prologue_state_pipe, params.num_iterations); + // Update pipeline state for next persistent tile + tile_prologue_state_pipe.advance(params.num_iterations); + tiles_per_cluster--; + } + } + // Math WarpGropups + if(warp_group_idx == 1 || warp_group_idx == 2) { + cutlass::arch::warpgroup_reg_alloc<232>(); + while (tiles_per_cluster > 0) { + // MMA + math_wg_barrier.wait(); + CollectiveSimulation::math_wg_simulation(pipeline, tile_start_state_pipe, params.num_iterations, params.data_ptr); + math_wg_barrier.arrive(); + // Epilogue + math_wg_barrier.wait(); + // Simulates long running stage + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + __nanosleep(100000); + #endif + math_wg_barrier.arrive(); + // Update pipeline state for next persistent tile + tile_start_state_pipe.advance(params.num_iterations * 2); + tiles_per_cluster -= 2; + } + } + + // Makes sure remote SMEM doesn't get destroyed + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} +///////////////////////////////////////////////////// + +/// Device NT GMMA + TMA specialized +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t kBlockSize = 128 * 3; + using ClusterShape = ClusterShape_; + + // + // Methods + // + + // Run CuTe GEMM kernel + cudaError_t run(uint32_t const kNumIters, + cudaStream_t stream = 0) { + + float elapsed_ms = 0.0f; + // Pipeline (multistage pipeline) + auto num_stages = Int{}; + auto cluster_shape = Shape, Int, _1>{}; + + // + // Configure and launch + // + int iterations = 1; + cudaEvent_t events[2]; + cudaError_t result; + + for (cudaEvent_t & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "Error: Failed to create event."; + return result; + } + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record start event."; + return result; + } + + for (int iter = 0; iter < iterations; ++iter) { + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + + constexpr int StagesPerMathWarpGroup = 2; + constexpr int MathWarpGroupCountPersistent = 2; + int smem_size = int(sizeof(SharedStorage>)); + + result = cudaFuncSetAttribute( + pipeline_device, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + // Launch a single Cluster, with kBlockSize threads per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimBlock(kBlockSize,1,1); + + int tiles_per_cluster = (kNumIters % 10) + 1; + printf("Persistent version: Tiles per Cluster = %d\n", tiles_per_cluster); + + const void* kernel = (const void*)pipeline_device; + KernelParams params{kNumIters, tiles_per_cluster, nullptr}; + void *kernel_params[] = {¶ms}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + } + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to record stop event."; + return result; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + if (result != cudaSuccess) { + std::cerr << "Failed to create event."; + return result; + } + + for (cudaEvent_t & event : events) { + (void)cudaEventDestroy(event); + } + + return cudaSuccess; + } +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage10) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 10; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage5) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 5; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage2) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage7) { + Options options; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 7; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/sequence_barrier.cu b/test/unit/pipeline/sequence_barrier.cu new file mode 100644 index 0000000000..f426ca0309 --- /dev/null +++ b/test/unit/pipeline/sequence_barrier.cu @@ -0,0 +1,226 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the OrderedSequenceBarrier class +*/ + +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed.h" +#include "cutlass/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" + +using namespace cute; + +//////////////////// KERNEL ///////////////////////// + +template +struct SharedStorage +{ + typename OrderedSequencer::SharedStorage storage; +}; + +// Goal of this kernel is to complete deadlock-free +template +__global__ static +void ordered_sequence_device(uint32_t const num_iterations) +{ + + extern __shared__ char shared_memory[]; + using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; + using SmemStorage = SharedStorage; + + SmemStorage& shared_storage = *reinterpret_cast(shared_memory); + + int group_idx = threadIdx.x / ThreadsPerGroup; + + typename SequenceBarrier::Params params; + params.group_id = group_idx; // sequence ID + params.group_size = ThreadsPerGroup; // Number of threads / participants in a group + + SequenceBarrier barrier(shared_storage.storage, params); + + // Ensure All CTAs in Cluster have completed init before issuing commits + __syncthreads(); + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < num_iterations; ++i){ + + barrier.wait(); + // STAGE 1 CODE... + #ifndef NDEBUG + int thread_idx_in_group = threadIdx.x % ThreadsPerGroup; + if (thread_idx_in_group == 0) { + printf("STAGE 0 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); + } + #endif + // Simulates long running stage + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + __nanosleep(100000); + #endif + barrier.arrive(); + + barrier.wait(); + // STAGE 2 CODE... + #ifndef NDEBUG + if (thread_idx_in_group == 0) { + printf("STAGE 1 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); + } + #endif + // Simulates long running stage + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + __nanosleep(100000); + #endif + barrier.arrive(); + } + + // To make sure remote SMEM doesn't get destroyed + cute::cluster_arrive(); + cute::cluster_wait(); +} +///////////////////////////////////////////////////// + +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t ThreadsPerGroup = 128; + static constexpr uint32_t BlockSize = GroupCount_ * ThreadsPerGroup; + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t GroupCount = GroupCount_; + using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; + using SmemStorage = SharedStorage; + + // + // Methods + // + + // Run CuTe GEMM kernel + cudaError_t run(uint32_t const kNumIters, + cudaStream_t stream = nullptr) { + + // Pipeline (multistage pipeline) + auto cluster_shape = Shape<_1, _1, _1>{}; + + // + // Configure and launch + // + int iterations = 1; + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + + int smem_size = int(sizeof(SmemStorage)); + + result = cudaFuncSetAttribute( + ordered_sequence_device, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + // Launch a single Cluster, with 128 thread per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimBlock(BlockSize,1,1); + + const void* kernel = (const void*)ordered_sequence_device; + int iters = kNumIters; + void* kernel_params[] = {reinterpret_cast(&iters)}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + } // profiling loop ends + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + return cudaSuccess; + } +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +TEST(SM90_Verify_OrderedSequence, Depth_2_Length_2) { + Options options; + static constexpr uint32_t GroupCount = 2; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_OrderedSequence, Depth_2_Length_3) { + Options options; + static constexpr uint32_t GroupCount = 3; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_OrderedSequence, Depth_2_Length_4) { + Options options; + static constexpr uint32_t GroupCount = 4; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +TEST(SM90_Verify_OrderedSequence, Depth_2_Length_5) { + Options options; + static constexpr uint32_t GroupCount = 5; + static constexpr uint32_t Stages = 2; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/testbed.h b/test/unit/pipeline/testbed.h new file mode 100644 index 0000000000..b809e74324 --- /dev/null +++ b/test/unit/pipeline/testbed.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Common Testbed file shared by Pipeline unit tests +*/ + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "../common/cutlass_unit_test.h" + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +// Command line test options +struct Options { + // + // Data Members + // + bool help; + bool verification_enabled; + int SM_count; + int clock_MHz; + + // + // Methods + // + Options(): + help(false), + verification_enabled(true), + SM_count(116), + clock_MHz(1477) + { } + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); + cmd.get_cmd_line_argument("sm-count", SM_count, 116); + cmd.get_cmd_line_argument("clock", clock_MHz, 1477); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +struct Testbed { +private: + // Commandline options + Options options; + + void run_test(uint32_t const kNumIters) { + + // Run CuTe Gemm + Pipeline pipeline; + + cudaError_t result = pipeline.run(kNumIters); + + CUTE_CHECK_LAST(); + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + + std::array kNumIters; + + for (int i = 0; i < kNumIters.size(); ++i) { + kNumIters[i] = (rand() % 1000) + 1; + } + + for (int n : kNumIters) { + std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; + run_test(n); + } + + return true; + } +}; diff --git a/test/unit/util/CMakeLists.txt b/test/unit/util/CMakeLists.txt index 4e3f197286..449d6f62a8 100644 --- a/test/unit/util/CMakeLists.txt +++ b/test/unit/util/CMakeLists.txt @@ -29,9 +29,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_util tensor_reduce.cu - ) - -cutlass_test_unit_add_executable( - cutlass_test_unit_levels cutlass_test_levels.cu ) diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index d40eebc5d1..6bb3f79965 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -354,6 +354,9 @@ struct TileDescription { /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. int maximum_compute_capability; + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + // // Methods // @@ -364,14 +367,16 @@ struct TileDescription { cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction = MathInstructionDescription(), int minimum_compute_capability = 0, - int maximum_compute_capability = 0 + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) ): threadblock_shape(threadblock_shape), threadblock_stages(threadblock_stages), warp_count(warp_count), math_instruction(math_instruction), minimum_compute_capability(minimum_compute_capability), - maximum_compute_capability(maximum_compute_capability) { } + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } // Equality operator inline @@ -991,6 +996,9 @@ struct GemmUniversalConfiguration { }; struct GemmUniversalArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size; + int batch_count; void const *A; void const *B; @@ -1001,6 +1009,12 @@ struct GemmUniversalArguments { void const *beta; ScalarPointerMode pointer_mode; + // NOTE: these are replicated for 3.0 interfaces + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t ldd; + int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; diff --git a/tools/library/scripts/__init__.py b/tools/library/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 45786089b9..e4c86a710d 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -25,6 +25,7 @@ class GemmOperation: def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + self.prefix = "3x" if gemm_kind == GemmKind.Universal3x else "" self.operation_kind = OperationKind.Gemm self.arch = arch self.tile_description = tile_description @@ -83,7 +84,11 @@ def core_name(self): math_op = self.tile_description.math_instruction.math_operation math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.gemm_kind == GemmKind.Universal3x: + inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) + else: + inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) + inst_shape += math_op_string if self.tile_description.math_instruction.element_a != self.A.element and \ @@ -92,7 +97,7 @@ def core_name(self): return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) - # + # Generates a string representing the MMA instruction. def extended_name(self): ''' Append data types if they differ from compute type. ''' if self.is_complex(): @@ -115,7 +120,17 @@ def extended_name(self): return extended_name - # + def extended_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_c = DataTypeNames[self.C.element], + core_name = self.core_name()) + return extended_name + + # Generates a short string representing the AB layout tags (e.g. nt or tn) def layout_name(self): if self.is_complex() or self.is_planar_complex(): return "%s%s" % ( @@ -124,25 +139,48 @@ def layout_name(self): ) return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - # + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates the full kernel function name def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment': "%d" % self.A.alignment, - } - ) + if self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}" + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name_3x(), + tbm = self.tile_description.threadblock_shape[0], + tbn = self.tile_description.threadblock_shape[1], + tbk = self.tile_description.threadblock_shape[2], + cm = self.tile_description.cluster_shape[0], + cn = self.tile_description.cluster_shape[1], + ck = self.tile_description.cluster_shape[2], + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(max(self.A.alignment, self.B.alignment))) + else: + threadblock = self.tile_description.procedural_name() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + op = opcode_class_name, + ex = self.extended_name(), + tb = threadblock, + l = self.layout_name(), + a = str(self.A.alignment)) # def configuration_name(self): @@ -551,6 +589,142 @@ def emit(self, operation): return SubstituteTemplate(gemm_template, values) + +################################################################################################### + +# +class EmitGemmUniversal3xInstance: + ''' Responsible for emitting a CUTLASS 3.x template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/default_epilogue.hpp", + "cutlass/epilogue/thread/linear_combination.h", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ + +using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using ${operation_name}_epilogue = + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t<${layout_c}>, + cutlass::gemm::TagToStrideC_t<${layout_c}>, + cutlass::epilogue::thread::LinearCombination< + ${element_c}, ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue}> + >; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue>; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; + +""" + # + def instance_template(self): + return """ +${compile_guard_start} + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + else: + stage_count_string = "cutlass::gemm::collective::StageCountAuto" + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'cluster_m': str(operation.tile_description.cluster_shape[0]), + 'cluster_n': str(operation.tile_description.cluster_shape[1]), + 'cluster_k': str(operation.tile_description.cluster_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'stages': stage_count_string, + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + } + + return SubstituteTemplate(self.gemm_template, values) + ################################################################################################### # @@ -868,6 +1042,7 @@ def __init__(self, operation_path, configuration_name): GemmKind.Gemm: EmitGemmInstance, GemmKind.Sparse: EmitSparseGemmInstance, GemmKind.Universal: EmitGemmUniversalInstance, + GemmKind.Universal3x: EmitGemmUniversal3xInstance, GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, GemmKind.Grouped: EmitGemmGroupedInstance @@ -877,6 +1052,7 @@ def __init__(self, operation_path, configuration_name): GemmKind.Gemm: 'GemmOperation', GemmKind.Sparse: 'GemmSparseOperation', GemmKind.Universal: 'GemmUniversalOperation', + GemmKind.Universal3x: 'GemmUniversal3xOperation', GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', GemmKind.Grouped: 'GemmGroupedOperation' @@ -931,7 +1107,9 @@ def __enter__(self): ("cutlass/library/manifest.h", None), ("library_internal.h", None), ("gemm_operation.h", None), + ("gemm_operation_3x.hpp", None), ("cutlass/arch/wmma.h", None), + ("cutlass/numeric_types.h", None) ]) self.instance_definitions = [] self.instance_wrappers = [] diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 4e4faae7ad..6d5f8308fe 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -84,6 +84,43 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ return operations + +# Generates 3.0 API based GemmUniversal API kernels. Alignment constraits are folded in with layouts +def CreateGemmUniversal3xOperator( + manifest, layouts, tile_descriptions, data_type, + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for complex_transform in complex_transforms: + A = TensorDescription( + element_a, layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + element_b, layout[1][0], layout[1][1], complex_transform[1]) + C = TensorDescription(element_c, layout[2][0], layout[2][1]) + + operation = GemmOperation( + GemmKind.Universal3x, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(operation) + operations.append(operation) + + return operations + # def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ @@ -3959,6 +3996,187 @@ def GenerateSM80(manifest, cuda_version): ################################################################################################### +# +def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = [ + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed) + + +# +def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments + layouts_tf32 = [ + [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + ] + + math_inst = MathInstruction( + [64, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + tile_descriptions = [ + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + data_type_tf32 = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmUniversal3xOperator(manifest, layouts_tf32, tile_descriptions, data_type_tf32) + + # F32 kernel, TN only supported for now + layouts_f32 = [layouts_tf32[2]] + + data_type_f32 = [ + DataType.f32, + DataType.f32, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmUniversal3xOperator(manifest, layouts_f32, tile_descriptions, data_type_f32) + + +def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) + # def GenerateSM90_TensorOp_1684(manifest, cuda_version): @@ -3972,11 +4190,10 @@ def GenerateSM90_TensorOp_1684(manifest, cuda_version): (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), ] - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ + math_inst = MathInstruction( + [16, 8, 4], + DataType.f64, DataType.f64, DataType.f64, + OpcodeClass.TensorOp, MathOperation.multiply_add) min_cc = 90 @@ -4002,7 +4219,7 @@ def GenerateSM90_TensorOp_1684(manifest, cuda_version): data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - CreateGemmOperator(manifest, layouts, tile_descriptions, \ + CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints) # @@ -4564,11 +4781,12 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): # def GenerateSM90(manifest, cuda_version): - + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_1684(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) - GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index 3dd57409c3..6919479e40 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -51,6 +51,8 @@ class DataType(enum.Enum): s16 = enum_auto() s32 = enum_auto() s64 = enum_auto() + e4m3 = enum_auto() + e5m2 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() @@ -76,6 +78,8 @@ class DataType(enum.Enum): # ShortDataTypeNames = { DataType.s32: 'i', + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', DataType.f16: 'h', DataType.f32: 's', DataType.f64: 'd', @@ -96,6 +100,8 @@ class DataType(enum.Enum): DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", @@ -130,6 +136,8 @@ class DataType(enum.Enum): DataType.s16: "int16_t", DataType.s32: "int32_t", DataType.s64: "int64_t", + DataType.e4m3: 'cutlass::float_e4m3_t', + DataType.e5m2: 'cutlass::float_e5m2_t', DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", @@ -164,6 +172,8 @@ class DataType(enum.Enum): DataType.s16: 16, DataType.s32: 32, DataType.s64: 64, + DataType.e4m3: 8, + DataType.e5m2: 8, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, @@ -464,13 +474,15 @@ class Target(enum.Enum): 70: 'volta', 75: 'turing', 80: 'ampere', + 89: 'ada', + 90: 'hopper' } # SharedMemPerCC = { - 70: 96, # 96KB of SMEM - 72: 96, # 96KB of SMEM - 75: 64, # 64KB of SMEM + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM 80: 163, # 163KB of SMEM - 1KB reserved for the driver 86: 99, # 99KB of SMEM - 1KB reserved for the driver 87: 163, # 163KB of SMEM - 1KB reserved for the driver @@ -501,6 +513,7 @@ class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() Universal = enum_auto() + Universal3x = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() Grouped = enum_auto() @@ -510,6 +523,7 @@ class GemmKind(enum.Enum): GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", + GemmKind.Universal3x: "gemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped" @@ -697,16 +711,60 @@ def __init__(self, instruction_shape, element_a, element_b, element_accumulator, # class TileDescription: - def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute): + def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]): self.threadblock_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute + self.cluster_shape = cluster_shape def procedural_name(self): - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + if self.minimum_compute_capability >= 90: + return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( + tbm = self.threadblock_shape[0], + tbn = self.threadblock_shape[1], + tbk = self.threadblock_shape[2], + cm = self.cluster_shape[0], + cn = self.cluster_shape[1], + ck = self.cluster_shape[2], + s = self.stages) + else: + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name # class Direct2dConvFixedStrideDilationTileDescription: diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 345ecd45a1..966f418e1f 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -204,7 +204,10 @@ def __init__(self, args = None): if self.args: self.kernel_filter = self.args.kernels self.curr_build_dir = args.curr_build_dir + architectures = args.architectures.split(';') if len(args.architectures) else ['50',] + architectures = [x if x != '90a' else '90' for x in architectures] + self.compute_capabilities = [int(x) for x in architectures] if args.filter_by_cc in ['false', 'False', '0']: @@ -348,6 +351,8 @@ def append(self, operation): self.operations[operation.operation_kind][configuration_name].append(operation) self.operation_count += 1 + else: + print("Culled {} from manifest".format(operation.procedural_name())) # # diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md index 2843298b45..dd2e7d0e4e 100644 --- a/tools/library/scripts/pycutlass/README.md +++ b/tools/library/scripts/pycutlass/README.md @@ -81,13 +81,24 @@ The tiling size of above operations can also be customized. ## Installation ### Using Docker -You can run the PyCUTLASS on NGC PyTorch container. +We recommend using one of our provided Docker images for using PyCUTLASS. + +**To run CUTLASS 3 GEMM kernels targetting the NVIDIA Hopper architecture via PyCUTLASS,** you can use an included [Dockerfile](docker/Dockerfile-cuda12.0) based on the NGC CUDA 12.0 container: +```shell +docker build -t pycutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0 . +docker run --gpus all -it --rm pycutlass-cuda12.0:latest +``` +Note that this Docker container does not include CuPy or PyTorch, and, thus, will not be able to run PyCUTLASS examples that +leverage these packages. + +**To run CUTLASS 2.x kernels targetting pre-SM90 architectures via PyCUTLASS,** you can use an included [Dockerfile](docker/Dockerfile-cuda11.8-pytorch) based on an NGC PyTorch container: ```shell -docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3 +docker build -t pycutlass-cuda11.8-pytorch:latest -f docker/Dockerfile-cuda11.8-pytorch . +docker run --gpus all -it --rm pycutlass-cuda11.8-pytorch:latest ``` ### Environment variables -PyCUTLASSS requires two environment variables: +PyCUTLASS requires two environment variables: * `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`. * `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')` diff --git a/tools/library/scripts/pycutlass/build.sh b/tools/library/scripts/pycutlass/build.sh index cffc85a645..5dbda5d505 100644 --- a/tools/library/scripts/pycutlass/build.sh +++ b/tools/library/scripts/pycutlass/build.sh @@ -1,4 +1,36 @@ -pip install pybind11 +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +pip install -U pybind11 git clone https://github.com/google/googletest.git -python setup.py install +python setup.py develop --user python setup.py rmm diff --git a/tools/library/scripts/pycutlass/build_doc.sh b/tools/library/scripts/pycutlass/build_doc.sh index aa7ef7c794..3fad0808e3 100644 --- a/tools/library/scripts/pycutlass/build_doc.sh +++ b/tools/library/scripts/pycutlass/build_doc.sh @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + pip install enum-tools pip install sphinx-toolbox pip install m2r2 diff --git a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch b/tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch new file mode 100644 index 0000000000..c36e0e2ec2 --- /dev/null +++ b/tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch @@ -0,0 +1,40 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +FROM nvcr.io/nvidia/pytorch:22.11-py3 + +RUN chmod ugo+rwx /home +RUN pip uninstall -y rmm +RUN pip install rmm-cu11 --extra-index-url=https://pypi.ngc.nvidia.com +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH +ENV CUDA_INSTALL_PATH=/usr/local/cuda diff --git a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 b/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 new file mode 100644 index 0000000000..f81d79d01c --- /dev/null +++ b/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +FROM nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu20.04 + +RUN apt-get update +RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata +RUN apt-get install -y git cmake vim python3 python3-pip +RUN ln -s /usr/bin/python3 /usr/bin/python +RUN chmod ugo+rwx /home +RUN pip install numpy==1.23 +RUN pip install cudf-cu11 dask-cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com +RUN pip install cuml-cu11 --extra-index-url=https://pypi.ngc.nvidia.com +RUN pip install cugraph-cu11 --extra-index-url=https://pypi.ngc.nvidia.com +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu/:$LIBRARY_PATH +ENV CUDA_INSTALL_PATH=/usr/local/cuda diff --git a/tools/library/scripts/pycutlass/setup.py b/tools/library/scripts/pycutlass/setup.py index 219face050..bf950ae81a 100644 --- a/tools/library/scripts/pycutlass/setup.py +++ b/tools/library/scripts/pycutlass/setup.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import distutils.cmd from setuptools import setup import setuptools.command.build_py @@ -15,7 +47,7 @@ def run(self): import rmm except ImportError: print("installing rmm") - os.system("git clone -b branch-22.08 --recurse-submodules https://github.com/rapidsai/rmm.git") + os.system("git clone -b branch-22.10 --recurse-submodules https://github.com/rapidsai/rmm.git") os.chdir("./rmm") os.system("./build.sh librmm rmm") os.chdir("./python") @@ -43,7 +75,11 @@ def run(self): Pybind11Extension("cutlass", ["src/cpp/cutlass.cpp"], include_dirs=include_dirs, - extra_compile_args=["-fpermissive", "-w"]) + extra_compile_args=["-fpermissive", "-w", "-std=c++17"]), + Pybind11Extension("cute", + ["src/cpp/cute.cpp"], + include_dirs=include_dirs, + extra_compile_args=["-fpermissive", "-w", "-std=c++17"]) ] except ImportError: pass @@ -65,7 +101,7 @@ def run(self): install_requires=[ "numpy<1.23", 'pybind11', - 'cuda-python<11.7.0', + 'cuda-python>=11.8.0', 'typeguard', 'bfloat16', 'typing', diff --git a/tools/library/scripts/pycutlass/src/cpp/cute.cpp b/tools/library/scripts/pycutlass/src/cpp/cute.cpp new file mode 100644 index 0000000000..8995159e10 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/cute.cpp @@ -0,0 +1,54 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief binding CuTe C++ APIs to Python +*/ + +#include +#include + +#include "cute/arch/mma_sm90_gmma.hpp" + +namespace py = pybind11; + + +PYBIND11_MODULE(cute, m) { + + // module doc + m.doc() = "CuTe C++ bindings"; + + py::enum_(m, "GMMAMajor", + R"pbdoc(classification of CuTe GMMA tensor major specification)pbdoc") + .value("K", cute::GMMA::Major::K, + R"pbdoc(Tensor is contiguous in reduction dimension)pbdoc") + .value("MN", cute::GMMA::Major::MN, + R"pbdoc(Tensor is contiguous in non-reduction dimension)pbdoc"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp b/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp index 5971b9138c..9e4718826e 100644 --- a/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp +++ b/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp @@ -29,8 +29,9 @@ * **************************************************************************************************/ /* \file - \brief binding cutlass C++ APIs to python + \brief binding CUTLASS C++ APIs to Python */ + #include #include diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h index cd0076db35..6b33f9a350 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h @@ -34,6 +34,7 @@ \brief A generic wrapper around an epilogue visitor operation */ + #pragma once #include "cutlass/cutlass.h" diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h index c2fee45ba6..f64066a0eb 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Binary operations to be used within the epilogue visitor model. + + \brief A file contains the binary ops */ #pragma once @@ -44,7 +44,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Elementwise addition of two arrays +/// Scalar multiplication template struct VectorAdd { diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h index 90d4601a91..9952a52bc8 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Unary operations to be used within the epilogue visitor model. + + \brief A file contains the unary ops */ #pragma once diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h index 519fb23970..2072cfaf26 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that simply returns the accumulator + + \brief A file contains the epilogue visitor Op with accumulator */ #pragma once diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h index 34daa27f26..d9fa4458b1 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operator performing a binary operation between two visitor nodes + + \brief A file contains the epilogue visitor Op with Binary op */ #pragma once @@ -84,7 +84,6 @@ class VisitorOpBinary{ /// Fragment type of accumulator using AccumulatorAccessType = Array; - /// Combination Op TODO: generalize this using BinaryOp = BinaryOp_; static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h index 5c1c6938ca..6dcb32b27f 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that broadcasts a vector to all columns + + \brief A file contains the epilogue visitor Op with broadcasting vector to all columns */ #pragma once diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h index 6730ba0c2c..624d7e681d 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that performs a column-wise reduction within a threadblock + + \brief A file contains the epilogue visitor Op with reduction over columns in CTA */ #pragma once @@ -68,7 +68,6 @@ class VisitorOpColumnReduction { static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - // TODO: generalize the reduction op using ReductionOp = cutlass::plus>; using ReductionOpScalar = cutlass::plus; using ElementOutput = typename OutputTileIterator::Element; diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h index da38829801..1e2b8e61d4 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that performs a linear combination of two visitor nodes + + \brief A file contains the epilogue visitor Op with Linear Combination */ #pragma once @@ -82,7 +82,7 @@ class VisitorOpLinearCombination{ /// Fragment type of accumulator using AccumulatorAccessType = Array; - /// Combination Op TODO: generalize this + /// Combination Op using CombinationOp = cutlass::plus; static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h index c0acb7bfbb..dc7bfa2f49 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that broadcasts a vector to all rows + + \brief A file contains the epilogue visitor Op with broadcasting vector to all rows */ #pragma once diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h index 33420e857e..27b03f843a 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operation that performs a column-wise reduction within a threadblock + + \brief A file contains the epilogue visitor Op with reduction over rows in CTA */ #pragma once @@ -69,7 +69,6 @@ class VisitorOpRowReduction { static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - // TODO: generalize the reduction op using ReductionOp = cutlass::plus>; using ReductionOpScalar = cutlass::plus; using ElementOutput = typename OutputTileIterator::Element; diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h index 1a95b41e2d..c80543ea3a 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h @@ -30,8 +30,8 @@ **************************************************************************************************/ /*! \file - - \brief Epilogue visitor operator performing a unary operation atop a visitor node + + \brief A file contains the epilogue visitor Op with Unary operation */ #pragma once @@ -79,7 +79,7 @@ class VisitorOpUnary{ /// Fragment type of accumulator using AccumulatorAccessType = Array; - /// Combination Op TODO: generalize this + /// Combination Op using UnaryOp = UnaryOp_; static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h index 758814bac7..64b65a03a8 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -139,8 +139,8 @@ struct GemmUniversalwithEpilogueVisitor { // // Methods // - - Arguments(): + + Arguments(): ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_gather_A_indices(nullptr), ptr_gather_B_indices(nullptr), @@ -169,8 +169,8 @@ struct GemmUniversalwithEpilogueVisitor { int const *ptr_scatter_D_indices = nullptr ): UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue_visitor(epilogue_visitor), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + epilogue_visitor(epilogue_visitor), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), @@ -205,8 +205,8 @@ struct GemmUniversalwithEpilogueVisitor { int const *ptr_scatter_D_indices = nullptr ): UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue_visitor(epilogue_visitor), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + epilogue_visitor(epilogue_visitor), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), @@ -221,7 +221,7 @@ struct GemmUniversalwithEpilogueVisitor { /// Returns arguments for the transposed problem Arguments transposed_problem() const { Arguments args(*this); - + std::swap(args.problem_size.m(), args.problem_size.n()); std::swap(args.ptr_A, args.ptr_B); std::swap(args.lda, args.ldb); @@ -256,7 +256,7 @@ struct GemmUniversalwithEpilogueVisitor { typename Mma::IteratorB::Params params_B; typename EpilogueVisitor::OutputTileIterator::Params params_C; typename EpilogueVisitor::OutputTileIterator::Params params_D; - + typename EpilogueVisitor::Params epilogue_visitor; void * ptr_A; @@ -325,7 +325,7 @@ struct GemmUniversalwithEpilogueVisitor { batch_stride_C = args.batch_stride_C; epilogue_visitor = args.epilogue_visitor; - + semaphore = static_cast(workspace); CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } @@ -345,7 +345,7 @@ struct GemmUniversalwithEpilogueVisitor { // CUTLASS_DEVICE - GemmUniversalwithEpilogueVisitor() { } + GemmUniversalwithEpilogueVisitor() { } /// Determines whether kernel satisfies alignment static Status can_implement( @@ -455,12 +455,12 @@ struct GemmUniversalwithEpilogueVisitor { // // Fetch pointers based on mode. // - if (params.mode == GemmUniversalMode::kGemm || + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; @@ -529,10 +529,10 @@ struct GemmUniversalwithEpilogueVisitor { // Compute threadblock-scoped matrix multiply-add mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, accumulators); // @@ -555,30 +555,16 @@ struct GemmUniversalwithEpilogueVisitor { int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_C = static_cast(params.ptr_C); ElementC *ptr_D = static_cast(params.ptr_D); // // Fetch pointers based on mode. // - + // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); - // if (params.mode == GemmUniversalMode::kGemm) { - - // // TODO: fix this order - // // If performing a reduction via split-K, fetch the initial synchronization - // if (params.grid_tiled_shape.k() > 1) { - - // // Fetch the synchronization lock initially but do not block. - // semaphore.fetch(); - - // // Indicate which position in a serial reduction the output operator is currently updating - // output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - // } - // } - // Tile iterator loading from source tensor. EpilogueVisitor epilogue_visitor( @@ -590,9 +576,6 @@ struct GemmUniversalwithEpilogueVisitor { params.problem_size.mn() ); - // if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - // ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - // } if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } @@ -605,25 +588,20 @@ struct GemmUniversalwithEpilogueVisitor { // Wait on the semaphore - this latency may have been covered by iterator construction if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - // TODO: ??? - // if (threadblock_tile_offset.k()) { - // iterator_C = iterator_D; - // } + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. semaphore.wait(threadblock_tile_offset.k()); } // Execute the epilogue operator to update the destination tensor. - epilogue(epilogue_visitor, accumulators); - + epilogue(epilogue_visitor, accumulators); + // // Release the semaphore // - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { @@ -635,7 +613,7 @@ struct GemmUniversalwithEpilogueVisitor { // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } - + semaphore.release(lock); } } diff --git a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h index 830ff76680..43991e4658 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h @@ -83,7 +83,6 @@ void bind_identity_swizzle(py::module & m, std::string name) { :param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC) :type problem_size: :class:`cutlass.gemm.GemmCoord`) )pbdoc") - // TODO: the returned dim3 is not usable in python .def("get_grid_shape", &T::get_grid_shape, py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") diff --git a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py index 8972c6fa65..18f3e84db7 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py @@ -31,6 +31,7 @@ def SubstituteTemplate(template, values): from pycutlass.frontend import * from pycutlass.reduction_operation import * from pycutlass.compiler import * +from pycutlass.utils.device import device_cc # module-wide variables @@ -40,6 +41,12 @@ def SubstituteTemplate(template, values): # artifact manager this.compiler = ArtifactManager() +try: + if not hasattr(this, 'DEVICE_CC') or this.DEVICE_CC is None: + this.DEVICE_CC = device_cc() +except: + this.DEVICE_CC = None + def get_memory_pool(init_pool_size=0, max_pool_size=2**34): this.memory_pool = PoolMemoryManager( init_pool_size=init_pool_size, diff --git a/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py b/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py new file mode 100644 index 0000000000..3e915261b5 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py @@ -0,0 +1,395 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for stamping out collective mainloops for SM90 kernels +""" + +import cute +import cutlass +from pycutlass import SubstituteTemplate +import pycutlass.library as library + + +tma_alignment_bytes = 16 +cp_async_min_alignment_bytes = 4 + + +class RowColMajorToGMMAMajor: + @staticmethod + def A(layout, element): + """ + Converts operand A's layout from row/column major format into CuTe's GMMA major format + + :param layout: layout of the A operand + :type layout: cutlass.RowMajor or cutlass.ColumnMajor + :param element: data type of the A operand + + :return: C++ CuTe GMMA major format + :rtype: cute.GMMAMajor + """ + type_requires_k_major = (element == cutlass.tfloat32) or (element == cutlass.int8) + if layout == cutlass.ColumnMajor and not type_requires_k_major: + return cute.GMMAMajor.MN + else: + return cute.GMMAMajor.K + + @staticmethod + def B(layout, element): + """ + Converts operand B's layout from row/column major format into CuTe's GMMA major format + + :param layout: layout of the B operand + :type layout: cutlass.RowMajor or cutlass.ColumnMajor + :param element: data type of the B operand + + :return: C++ CuTe GMMA major format + :rtype: cute.GMMAMajor + """ + type_requires_k_major = (element == cutlass.tfloat32) or (element == cutlass.int8) + if layout == cutlass.RowMajor and not type_requires_k_major: + return cute.GMMAMajor.MN + else: + return cute.GMMAMajor.K + + +def cluster_shape_to_tma(dim): + """ + Returns the TMA copy type for a given cluster dimension + + :param dim: a given dimension of a cluster + :type dim: layout + + :return: C++ TMA copy time + :rtype: str + """ + return 'cute::SM90_TMA_LOAD' if dim == 1 else 'cute::SM90_TMA_LOAD_MULTICAST' + + +def make_cpasync_gmem_tiled_copy(thread_count, element, alignment, gmma_layout, dim_mn, dim_k): + """ + Returns a `make_tiled_copy` call for a given configuraiton + + :param thread_count: number of threads in the threadblock + :type thread_count: int + :param element: datatype of the operand in question + :param alignment: byte alignment of the operand in question + :type alignment: int + :param gmma_layout: GMMA layout of the operand in question + :type gmma_layout: cute.GMMAMajor + :param dim_mn: extent of the M/N dimension of the tile + :type dim_mn: int + :param dim_k: extent of the reduction dimension of the tile + :type dim_k: int + + :return: C++ call to `make_tiled_copy` + :rtype: str + """ + + emission_str = """decltype(cute::make_tiled_copy( + cute::Copy_Atom(sizeof(${element})) * ${alignment}>>, ${element}>{}, + cute::Layout, + cute::Stride<_${stride_x}, _${stride_y}>>{}, + cute::Layout>{}))""" + if gmma_layout == cute.GMMAMajor.K: + threads_major = dim_k // alignment + threads_minor = thread_count // threads_major + values = { + 'shape0_x': str(threads_minor), + 'shape0_y': str(threads_major), + 'stride_x': str(threads_major), + 'stride_y': '1', + 'shape1_x': '1', + 'shape1_y': str(alignment) + } + elif gmma_layout == cute.GMMAMajor.MN: + threads_major = dim_mn // alignment + threads_minor = thread_count // threads_major + values = { + 'shape0_x': str(threads_major), + 'shape0_y': str(threads_minor), + 'stride_x': '1', + 'stride_y': str(threads_major), + 'shape1_x': str(alignment), + 'shape1_y': '1' + } + else: + raise Exception('Unexpected GMMA layout {}'.format(gmma_layout)) + + # Add common values + values['element'] = library.DataTypeTag[element] + values['alignment'] = str(alignment) + return SubstituteTemplate(emission_str, values) + + +def max_stages(op, arch): + """ + Returns the maximum number pipeline stages that can be used for an operation. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: pycutlass.GemmOperation + :param arch: compute capability of the device on which the operation will be run + :type arch: int + + :return: maximum number of pipeline stages that can be used for an operation + :rtype: int + """ + smem_per_stage = library.CalculateSmemUsagePerStage(op) + smem_capacity = library.SharedMemPerCC[arch] + return int(smem_capacity // smem_per_stage) + + +class LayoutToStride: + _variable_first = 'cute::Stride, int64_t>' + _variable_last = 'cute::Stride, int64_t, int64_t>' + + @staticmethod + def A(layout): + """ + Returns the CuTe shape type corresponding to the layout of operand A + + :param layout: layout of the B operand + :type layout: cutlass.RowMajor or cutlass.ColumnMajor + + :return: C++ declaration of CuTe stride + :rtype: str + """ + if layout == cutlass.RowMajor: + return LayoutToStride._variable_first + elif layout == cutlass.ColumnMajor: + return LayoutToStride._variable_last + else: + raise Exception('Unsupported layout {}'.format(layout)) + + @staticmethod + def B(layout): + """ + Returns the CuTe shape type corresponding to the layout of operand B + + :param layout: layout of the B operand + :type layout: cutlass.RowMajor or cutlass.ColumnMajor + + :return: C++ declaration of CuTe stride + :rtype: str + """ + if layout == cutlass.RowMajor: + return LayoutToStride._variable_last + elif layout == cutlass.ColumnMajor: + return LayoutToStride._variable_first + else: + raise Exception('Unsupported layout {}'.format(layout)) + + +EMISSION_STR = """ +using TileShape_MNK = cute::Shape<_${threadblock_shape_m}, _${threadblock_shape_n}, _${threadblock_shape_k}>; +using ClusterShape_MNK = cute::Shape<_${cluster_shape_m}, _${cluster_shape_n}, _${cluster_shape_k}>; +using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ${internal_element_A}, ${internal_element_B}, ${element_accumulator}, TileShape_MNK, ${gmma_layout_A}, ${gmma_layout_B}>())); + +using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<${gmma_layout_A}, ${internal_element_A}, _${threadblock_shape_m}, _${threadblock_shape_k}>()); +using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<${gmma_layout_B}, ${internal_element_B}, _${threadblock_shape_n}, _${threadblock_shape_k}>()); + +using CollectiveOp = typename cutlass::gemm::collective::CollectiveMma< + ${mainloop_type}<${stage_count}, ClusterShape_MNK${kernel_schedule}>, + TileShape_MNK, + ${element_A}, + ${stride_A}, + ${element_B}, + ${stride_B}, + TiledMma, + ${gmem_tiled_copy_A}, + SmemLayoutAtomA, + void, // GMMA_SS does not need an SmemCopyAtom + ${transform_A}, + ${gmem_tiled_copy_B}, + SmemLayoutAtomB, + void, // GMMA_SS does not need an SmemCopyAtom + ${transform_B} +>; +""" + + +def internal_element(element): + """ + Returns the data type internally used for `element`. + + :param element: data type + + :return: data type used internally + """ + return cutlass.tfloat32 if element == cutlass.float32 else element + + +def common_values(op, stage_count, transform_A, transform_B): + """ + Returns a dictionary containing common values to be substituted in the emission of the + collective operation declaration. Values specific to a particular collective operation + should be added to these. + + :param op: GEMM operation for which to build a collective operation + :type op: pycutlass.GemmOperation + :param stage_count: number of pipeline stages to use in the operation + :type stage_count: int + :param transform_A: transformation to perform on the A operand + :type transform_A: str + :param transform_B: transformation to perform on the B operand + :type transform_B: str + + :return: dictionary containing values to substitute in emission string + :rtype: dict + """ + internal_element_a = internal_element(op.A.element) + internal_element_b = internal_element(op.B.element) + + return { + 'threadblock_shape_m': str(op.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(op.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(op.tile_description.threadblock_shape[2]), + 'cluster_shape_m': str(op.tile_description.cluster_shape[0]), + 'cluster_shape_n': str(op.tile_description.cluster_shape[1]), + 'cluster_shape_k': str(op.tile_description.cluster_shape[2]), + 'element_A': library.DataTypeTag[op.A.element], + 'element_B': library.DataTypeTag[op.B.element], + 'internal_element_A': library.DataTypeTag[internal_element_a], + 'internal_element_B': library.DataTypeTag[internal_element_b], + 'element_accumulator': library.DataTypeTag[op.accumulator_type()], + 'gmma_layout_A': library.CuTeLayoutTag[RowColMajorToGMMAMajor.A(op.A.layout, internal_element_a)], + 'gmma_layout_B': library.CuTeLayoutTag[RowColMajorToGMMAMajor.B(op.B.layout, internal_element_b)], + 'stride_A': LayoutToStride.A(op.A.layout), + 'stride_B': LayoutToStride.B(op.B.layout), + 'stage_count': str(stage_count), + 'transform_A': transform_A, + 'transform_B': transform_B + } + + +def build_gmma_tma(op): + """ + Builds a collective operation declaration targetting TMA GMMA kernels + + :param op: GEMM operation for which to build a collective operation + :type op: pycutlass.GemmOperation + + :return: string containing the C++ declaration of collective operation + :rtype: str + """ + A_tma_aligned = (library.DataTypeSizeBytes[op.A.element] * op.A.alignment) % tma_alignment_bytes == 0 + B_tma_aligned = (library.DataTypeSizeBytes[op.B.element] * op.B.alignment) % tma_alignment_bytes == 0 + if not A_tma_aligned or not B_tma_aligned: + raise Exception('Each of the A or B operands must be aligned to {} bytes to use TMA'.format(tma_alignment_bytes)) + + max_stage_count = max_stages(op, arch=90) + if op.tile_description.stages is None: + op.tile_description.stages = max_stage_count + elif op.tile_description.stages > max_stage_count: + raise Exception('Combination of threadblock shape, data types, and number of stages exceeds shared memory capacity.') + + kernel_schedule = 'cutlass::gemm::KernelTmaWarpSpecialized' + if op.tile_description.persistent: + kernel_schedule = 'cutlass::gemm::KernelTmaWarpSpecializedPersistent' + + transform_A = 'cute::identity' + transform_B = 'cute::identity' + values = common_values(op, op.tile_description.stages, transform_A, transform_B) + specific_values = { + 'mainloop_type': 'cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized', + 'kernel_schedule': ', ' + kernel_schedule, + 'gmem_tiled_copy_A': cluster_shape_to_tma(op.tile_description.cluster_shape[1]), + 'gmem_tiled_copy_B': cluster_shape_to_tma(op.tile_description.cluster_shape[0]) + } + values.update(specific_values) + + return SubstituteTemplate(EMISSION_STR, values) + + +def build_gmma_cpasync(op): + """ + Builds a collective operation declaration targetting cp.async GMMA kernels + + :param op: GEMM operation for which to build a collective operation + :type op: pycutlass.GemmOperation + + :return: string containing the C++ declaration of collective operation + :rtype: str + """ + A_cp_async_aligned = (library.DataTypeSizeBytes[op.A.element] * op.A.alignment) % cp_async_min_alignment_bytes == 0 + B_cp_async_aligned = (library.DataTypeSizeBytes[op.B.element] * op.B.alignment) % cp_async_min_alignment_bytes == 0 + if not A_cp_async_aligned or not B_cp_async_aligned: + raise Exception('Each of the A or B operands must be aligned to {} bytes to use cp.async'.format(cp_async_min_alignment_bytes)) + + max_stage_count = max_stages(op, arch=90) + if op.tile_description.stages is None: + op.tile_description.stages = max_stage_count + elif op.tile_description.stages > max_stage_count: + raise Exception('Combination of threadblock shape, data types, and number of stages exceeds shared memory capacity.') + + transform_A = 'cute::identity' + transform_B = 'cute::identity' + + thread_count = 128 + cpasync_copy_A = make_cpasync_gmem_tiled_copy(thread_count, op.A.element, op.A.alignment, RowColMajorToGMMAMajor.A(op.A.layout, op.A.element), + op.tile_description.threadblock_shape[0], op.tile_description.threadblock_shape[2]) + cpasync_copy_B = make_cpasync_gmem_tiled_copy(thread_count, op.B.element, op.B.alignment, RowColMajorToGMMAMajor.B(op.B.layout, op.B.element), + op.tile_description.threadblock_shape[1], op.tile_description.threadblock_shape[2]) + + values = common_values(op, op.tile_description.stages, transform_A, transform_B) + specific_values = { + 'mainloop_type': 'cutlass::gemm::MainloopSm90CpAsyncGmma', + 'kernel_schedule': '', + 'gmem_tiled_copy_A': cpasync_copy_A, + 'gmem_tiled_copy_B': cpasync_copy_B + } + values.update(specific_values) + + return SubstituteTemplate(EMISSION_STR, values) + + +def build(operation): + """ + Builds a collective operation declaration targetting cp.async or TMA for GMMA kernels + + :param operation: GEMM operation for which to build a collective operation + :type operation: pycutlass.GemmOperation + + :return: string containing the C++ declaration of collective operation + :rtype: str + """ + A_tma_aligned = (library.DataTypeSizeBytes[operation.A.element] * operation.A.alignment) % tma_alignment_bytes == 0 + B_tma_aligned = (library.DataTypeSizeBytes[operation.B.element] * operation.B.alignment) % tma_alignment_bytes == 0 + tma_correct_size = (library.DataTypeSizeBytes[operation.A.element] == 2 and library.DataTypeSizeBytes[operation.B.element] == 2) + tma_correct_layout = (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor) + if A_tma_aligned and B_tma_aligned and (tma_correct_size or tma_correct_layout): + return build_gmma_tma(operation) + else: + return build_gmma_cpasync(operation) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py index 5822ccc88b..e5e985143e 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py @@ -33,8 +33,6 @@ import ctypes from pycutlass.library import * -# 12B - class GemmCoord_(ctypes.Structure): _fields_ = [ @@ -48,6 +46,24 @@ def __init__(self, gemm_coord) -> None: setattr(self, field_name, getattr(gemm_coord, field_name)()) +class GemmCoordBatched_(ctypes.Structure): + """ + Wrapper around a GemmCoord that also contains batch count. This is used for encoding + batched GEMM inputs to CUTLASS 3 GEMMs. + """ + _fields_ = [ + ("m", ctypes.c_int), + ("n", ctypes.c_int), + ("k", ctypes.c_int), + ("batch_count", ctypes.c_int) + ] + + def __init__(self, gemm_coord, batch_count) -> None: + for field_name, _ in self._fields_[:-1]: + setattr(self, field_name, getattr(gemm_coord, field_name)()) + setattr(self, "batch_count", batch_count) + + class MatrixCoord_(ctypes.Structure): _fields_ = [ ("row", ctypes.c_int), @@ -55,6 +71,26 @@ class MatrixCoord_(ctypes.Structure): ] +class dim3_(ctypes.Structure): + _fields_ = [ + ("x", ctypes.c_int), + ("y", ctypes.c_int), + ("z", ctypes.c_int) + ] + + +class StrideBatched_(ctypes.Structure): + """ + CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The + variable dimensions represent the stride along non-unit-stride dimension of the row/column major + layout, and the batch stride. This structure encodes the two variable dimensions. + """ + _fields_ = [ + ("major_stride", ctypes.c_int64), + ("batch_stride", ctypes.c_int64) + ] + + dtype2ctype = { cutlass.float16: ctypes.c_uint16, cutlass.float32: ctypes.c_float, @@ -63,6 +99,28 @@ class MatrixCoord_(ctypes.Structure): } +def get_gemm_arguments_3x(epilogue_functor): + + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoordBatched_), + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ("ptr_C", ctypes.c_void_p), + ("stride_C", StrideBatched_), + ("ptr_D", ctypes.c_void_p), + ("stride_D", StrideBatched_), + ("epilogue", _EpilogueOutputOpParams), + ] + + return _GemmArguments, _EpilogueOutputOpParams + + def get_gemm_arguments(epilogue_functor): _EpilogueOutputOpParams = epilogue_functor.epilogue_type @@ -103,8 +161,6 @@ class _GemmArguments(ctypes.Structure): # GEMM Grouped ########################################################################################### -# include/cutlass/gemm/kernel/gemm_grouped.h - def get_gemm_grouped_arguments(epilogue_functor): _EpilogueOutputOpParams = epilogue_functor.epilogue_type @@ -131,12 +187,6 @@ class _GEMMGroupedArguments(ctypes.Structure): # Convolution2D ############################################################################################ - -# We use the arguments as the interface - - -# include/cutlass/conv/conv2d_problem_size.h -# 64B class Conv2DProblemSize(ctypes.Structure): _fields_ = [ ("N", ctypes.c_int), @@ -164,8 +214,6 @@ def __init__(self, problem_size) -> None: setattr(self, field_name, getattr(problem_size, field_name)) -# include/cutlass/layout/tensor.h -# 12B class Layout4D(ctypes.Structure): _fields_ = [ ("stride", ctypes.c_int * 3) @@ -175,13 +223,7 @@ def __init__(self, tensor_ref): stride = tensor_ref.stride() setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2))) -# TODO: Tensor 5-D takes ("stride", ctypes.c_int * 4) - -# include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h -# TensorRef is basically cutlass::TensorRef; -# include/cutlass/tensor_ref.h -# 24B class TensorRef_(ctypes.Structure): _fields_ = [ ("ptr", ctypes.c_void_p), @@ -200,9 +242,6 @@ class TensorRef2D_(ctypes.Structure): ] -# include/cutlass/conv/kernel/implicit_gemm_convolution.h -# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4 - def get_conv2d_arguments(epilogue_functor): _EpilogueOutputOpParams = epilogue_functor.epilogue_type @@ -224,7 +263,6 @@ class _Conv2dArguments(ctypes.Structure): # Reduction ############################################################################################ - def get_reduction_params(epilogue_functor): _EpilogueOutputParams = epilogue_functor.epilogue_type diff --git a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py index cbcea67d5f..7671139132 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py @@ -29,6 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# +import pycutlass from pycutlass import * import cutlass from cuda import cuda @@ -54,11 +55,11 @@ class CompilationOptions: ''' # - def __init__(self, flags, architectures=[80], include_paths=[]): + def __init__(self, flags, arch, include_paths=[]): self.includes = [] self.include_paths = include_paths self.flags = flags - self.architectures = architectures + self.arch = arch def get_str(self): options = "" @@ -69,13 +70,11 @@ def get_str(self): for incl in self.include_paths: options += ' --include-path=%s' % incl - arch_list = "-arch=" - for idx, arch in enumerate(self.architectures): - if idx: - arch_list += "," - arch_list += "sm_%d" % arch + arch_flag = " -arch=sm_%d" % self.arch + if self.arch == 90: + arch_flag += 'a' + options += arch_flag - options += " " + arch_list return options # @@ -88,13 +87,11 @@ def get(self): for incl in self.include_paths: options.append(bytes(str.encode('--include-path=%s' % incl))) - arch_list = "-arch=" - for idx, arch in enumerate(self.architectures): - if idx: - arch_list += "," - arch_list += "sm_%d" % arch + arch_flag = " -arch=sm_%d" % self.arch + if self.arch == 90: + arch_flag += 'a' - options.append(bytes(str.encode(arch_list))) + options.append(bytes(str.encode(arch_flag))) return options @@ -138,12 +135,12 @@ def __init__(self) -> None: def nvrtc(self): self.backend = "nvrtc" self.default_compile_options = [ - '-std=c++11', '-default-device', + '-std=c++17', '-default-device' ] def nvcc(self): self.backend = "nvcc" self.default_compile_options = [ - '-std=c++11', + '-std=c++17', '--expt-relaxed-constexpr', '-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored' ] def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): connection = sqlite3.connect("./compiled_cache.db") @@ -158,7 +155,7 @@ def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): connection.commit() cursor.close() - def load_operation(self, op_key): + def load_operation(self, op_key, extra_funcs): connection = sqlite3.connect("./compiled_cache.db") cursor = connection.cursor() sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" @@ -194,12 +191,17 @@ def load_operation(self, op_key): if isinstance(attr, str): func_name = operation_name + '_' + attr func = getattr(host_lib, func_name) + + # Set the return type of the function + if attr in extra_funcs and extra_funcs[attr] != None: + func.restype = extra_funcs[attr] + compiled_host_fns[attr] = func self.compiled_cache_host.insert(key, compiled_host_fns) return True - def emit_compile_(self, operation_list, compilation_options): + def emit_compile_(self, operation_list, compilation_options, requires_nvcc_hostlib_compilation): """ Compile a list of kernels and store them into database """ @@ -276,6 +278,7 @@ def emit_compile_(self, operation_list, compilation_options): err, = nvrtc.nvrtcGetCUBIN(program, cubin_image) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise RuntimeError('NVRTC Error: {}'.format(err)) + else: # with nvcc backend # emit code tempfile.tempdir = "./" @@ -303,22 +306,34 @@ def emit_compile_(self, operation_list, compilation_options): with open(temp_cubin.name, 'rb') as file: cubin_image = file.read() - # compile the host code - options = compilation_options.get() - cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host - for opt in options: - opt = opt.decode("utf-8") - if opt not in ['-default-device', '-std=c++11', '-Xcicc', '-Xllc'] and '-arch=sm_' not in opt: - if '--include-path=' in opt: - cmd += " " + opt.replace('--include-path=', '-I') - else: - cmd += " " + opt + # Set up the host-side library code + if requires_nvcc_hostlib_compilation: + cuda_install_path = os.getenv('CUDA_INSTALL_PATH') + assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." + cmd_template = "echo '%s'|${cuda_install_path}/bin/nvcc -x cu -Xcompiler=\"-fpermissive -w -fPIC\" ${options}" % source_buffer_host + cmd = SubstituteTemplate( + cmd_template, + { + "cuda_install_path": cuda_install_path, + "options": compilation_options.get_str() + }) + else: + options = compilation_options.get() + cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host + filtered_opts = ['-default-device', '-Xcicc', '-Xllc', '--expt-relaxed-constexpr', '-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored'] + for opt in options: + opt = opt.decode("utf-8") + if opt not in filtered_opts and '-arch=sm_' not in opt: + if '--include-path=' in opt: + cmd += " " + opt.replace('--include-path=', '-I') + else: + cmd += " " + opt tempfile.tempdir = "./" temp = tempfile.NamedTemporaryFile( prefix='host_func', suffix='.so', delete=True) - cmd += ' - -shared -o %s' % temp.name + cmd += ' - -shared -o %s -lcudart -lcuda' % temp.name os.system(cmd) host_lib = ctypes.CDLL(temp.name) @@ -333,23 +348,25 @@ def add_module(self, operations, compile_options=None): assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." cuda_install_path = os.getenv('CUDA_INSTALL_PATH') assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." - architectures = [] - for operation in operations: - if hasattr(operation, "tile_description"): - cc = operation.arch - if cc not in architectures: - architectures.append(cc) include_paths = [ cuda_install_path + '/include', cutlass_path + '/include', cutlass_path + '/tools/util/include', cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include' ] + + if pycutlass.DEVICE_CC is not None: + arch = pycutlass.DEVICE_CC + else: + # Find the maximum arch tag among the provided operations and compile for that target. + # Since we are compiling to .cubin files, only one architecture may be specified. + arch = max([op.arch for op in operations]) compile_options = CompilationOptions( - self.default_compile_options, architectures, include_paths) + self.default_compile_options, arch, include_paths) # save the cubin operation_key = [] operation_list = [] + requires_nvcc_hostlib_compilation = False for operation in operations: # step 1: get kernel string as key key = operation.rt_module.emit() + operation.procedural_name() + self.backend @@ -357,7 +374,7 @@ def add_module(self, operations, compile_options=None): compiled_kernel = self.compiled_cache_device.at(key) if compiled_kernel is None: - hit = self.load_operation(key) + hit = self.load_operation(key, getattr(operation.rt_module, 'extra_funcs', {})) if hit: compiled_kernel = self.compiled_cache_device.at(key) assert compiled_kernel is not None @@ -371,9 +388,18 @@ def add_module(self, operations, compile_options=None): else: operation_list.append(operation.rt_module) operation_key.append(key) + + # Creating the Params structures for certain 3.0 kernels currently requires CUDA. For these cases, use NVCC to generate + # the PyCUTLASS host-side library. Otherwise, g++ will be used. + if isinstance(operation, pycutlass.gemm_operation.GemmOperationUniversal) and operation.api == pycutlass.library.ApiVersion.v3x: + if self.backend == "nvrtc": + raise RuntimeError('CUTLASS 3 kernels currently require NVCC for compilation.') + + requires_nvcc_hostlib_compilation = True + if len(operation_list) > 0: cubin_image, host_lib, host_file = self.emit_compile_( - operation_list, compile_options) + operation_list, compile_options, requires_nvcc_hostlib_compilation) err, module = cuda.cuModuleLoadData(cubin_image) if err != cuda.CUresult.CUDA_SUCCESS: @@ -417,9 +443,11 @@ def add_module(self, operations, compile_options=None): op_attr.append(param_size) if hasattr(operation, "extra_funcs"): - for suffix in operation.extra_funcs: + for suffix, ret_type in operation.extra_funcs.items(): func_name = operation.name() + '_' + suffix func = getattr(host_lib, func_name) + if ret_type is not None: + func.restype = ret_type setattr(operation, suffix, func) compiled_host_fns[suffix] = func op_attr.append(suffix) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py index 562fb4ef0b..0c4713cd66 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py @@ -463,13 +463,14 @@ def configuration_name(self): ) if self.stride_support == StrideSupport.Unity: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" else: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" return SubstituteTemplate( configuration_name, { + 'arch': str(self.arch), 'opcode_class': opcode_class_name, 'extended_name': self.extended_name(), 'threadblock': threadblock, @@ -509,7 +510,7 @@ def core_name(self): intermediate_type = '' if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp: - inst_shape = "%d%d%d" % tuple( + inst_shape = "%dx%dx%d" % tuple( self.tile_description.math_instruction.instruction_shape) if self.tile_description.math_instruction.element_a != self.A.element and \ self.tile_description.math_instruction.element_a != self.accumulator_type(): diff --git a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py index 88ee07eeb9..de6d53911e 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py @@ -111,6 +111,7 @@ def __init__( self.element_output = element_output self.element_accumulator = element_accumulator self.element_epilogue = element_epilogue + self.epilogue_vector_length = epilogue_vector_length self.template_arguments = [ DataTypeTag[element_output], str(epilogue_vector_length), diff --git a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py index 5246dd5914..bf59e43a89 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py @@ -36,6 +36,7 @@ from typeguard import typechecked import cutlass from pycutlass import * +import pycutlass.builder.collective_op_builder as collective_op_builder from cuda import cuda @@ -56,9 +57,9 @@ def transpose_layout(layout: cutlass.layout): # @typechecked -class GemmArguments(ArgumentBase): +class GemmArguments2x(ArgumentBase): """ - Argument wrapper for GEMM. It encodes problem information and + Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument @@ -148,7 +149,7 @@ def __init__( self.batch_count = 1 self.split_k_slices = self.batch_count - if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]: + if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]: if 'batch' in kwargs.keys(): self.batch_count = kwargs['batch'] else: @@ -313,6 +314,154 @@ def initialize(self): self.device_workspace = device_workspace self.launch_config = launch_config +class GemmArguments3x(GemmArguments2x): + """ + Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`pycutlass.GemmOperationUniversal` | + :class:`pycutlass.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass.gemm.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass.gemm.Mode` + + :param output_op: output operator, optional + :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + """ + + def __init__( + self, operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', + A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', + gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): + if gemm_mode not in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.Batched]: + raise Exception("Unsupporged GEMM mode {}.".format(gemm_mode)) + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count) + + if self.batch_count > 1: + bsA = self.batched_stride_A + bsB = self.batched_stride_B + bsC = self.batched_stride_C + bsD = self.batched_stride_D + else: + bsA = 0 + bsB = 0 + bsC = 0 + bsD = 0 + stride_A = StrideBatched_(self.lda, bsA) + stride_B = StrideBatched_(self.ldb, bsB) + stride_C = StrideBatched_(self.ldc, bsC) + stride_D = StrideBatched_(self.ldd, bsD) + + self.arguments = self.operation.argument_type( + self.gemm_mode, + problem_size_, + int(self.ptr_A), + stride_A, + int(self.ptr_B), + stride_B, + int(self.ptr_C), + stride_C, + int(self.ptr_D), + stride_D, + self.output_op, + ) + + def initialize(self): + # get the host and evice workspace + device_workspace_size = \ + self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if (workspace_ptr is not None and + self.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel): + # in GEMM splik-K parallel, the D pointer is redirected + # to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif (workspace_ptr is not None and + self.gemm_mode == cutlass.gemm.Mode.Gemm): + # in GEMM split-K serial + device_workspace = workspace_ptr + + self.get_arguments() + res_arg = self.operation.rt_module.get_args( + ctypes.byref(self.arguments), ctypes.c_void_p(int(device_workspace))) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(self.arguments), ctypes.c_void_p(int(device_workspace))) + block = self.operation.rt_module.get_block_shape() + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration([grid.x, grid.y, grid.z], + [block.x, block.y, block.z], + self.operation.rt_module.shared_memory_capacity) + +def GemmArguments(operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', + A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', + gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): + """ + Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments + or 3x arguments depending on the `arch` field specified in `operation`. + + :param operation: the GEMM operation to take the argument + :type operation: :class:`pycutlass.GemmOperationUniversal` | + :class:`pycutlass.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass.gemm.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass.gemm.Mode` + + :param output_op: output operator, optional + :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + """ + ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x + return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + class GemmGroupedArguments: """ @@ -383,7 +532,7 @@ def __init__( # process the input arguments for idx, problem_size in enumerate(problem_sizes): M, N, K = problem_size.m(), problem_size.n(), problem_size.k() - temp_argument = GemmArguments( + temp_argument = GemmArguments2x( operation=operation, problem_size=cutlass.gemm.GemmCoord(M, N, K), A=A[idx], B=B[idx], C=C[idx], D=D[idx], @@ -657,16 +806,164 @@ def get_device_workspace_size(self, arguments: GemmArguments): # workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y - # TODO: get extra workspace size - # see https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm_universal_base.h return workspace_bytes +################################################################################ +# Runtime module for GEMM Universal within CUTLASS 3 +################################################################################ + +class GemmRTUniversal3x(GemmRTUniversal): + """ + GemmRTUniversal manages the CUTLASS runtime components + """ + KernelTemplate = r''' + +using Operator = ${operation_name}${operation_suffix}; +extern "C" +__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +void ${operation_name}(__grid_constant__ typename Operator::Params const params) { + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + + // Declare pointer to dynamic shared memory. + Operator op; + op(params, smem); +} + ''' + HostTemplate = r''' +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return ${operation_name}${operation_suffix}::SharedStorageSize; + } + + using GemmType = ${operation_name}_base; + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){ + GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace); + + char *bytes = ((char*)(¶ms)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + // Get the grid shape + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) { + auto tmp_params = GemmType::to_underlying_arguments(*args, workspace); + return GemmType::get_grid_shape(tmp_params); + } + + // Get the block shape + dim3 ${operation_name}_get_block_shape() { + return GemmType::get_block_shape(); + } +} + ''' + + def __init__(self, operation: 'GemmOperation'): + super(GemmRTUniversal3x, self).__init__(operation) + self.extra_funcs = { + 'get_grid_shape': dim3_, + 'get_block_shape': dim3_ + } + self.emitter = EmitGemmUniversalInstance3x('_type') + self.argument_type, self.epilogue_type = get_gemm_arguments_3x(operation.epilogue_functor) + + +class EmitGemmUniversalInstance3x: + ''' Responsible for emitting a CUTLASS 3 template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cute/tensor.hpp", + "cute/atom/mma_atom.hpp", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/default_epilogue.hpp", + "cutlass/epilogue/thread/linear_combination.h" + ] + self.gemm_template = """ +using namespace cute; + +${collective_op} + +using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t<${layout_c}>, + cutlass::gemm::TagToStrideC_t<${layout_c}>, + ${epilogue_functor} + >; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp +>; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def emit(self, operation): + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + # Support built-in epilogue functors or user-defined functions + epilogue_functor = operation.epilogue_functor.emit() + + collective_op = collective_op_builder.build(operation) + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'collective_op': collective_op, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'epilogue_functor': epilogue_functor, + 'element_output': DataTypeTag[operation.epilogue_functor.element_output], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'element_epilogue': DataTypeTag[operation.epilogue_functor.element_epilogue], + 'epilogue_vector_length': str(operation.epilogue_functor.epilogue_vector_length), + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'cluster_shape_m': str(operation.tile_description.cluster_shape[0]), + 'cluster_shape_n': str(operation.tile_description.cluster_shape[1]), + 'cluster_shape_k': str(operation.tile_description.cluster_shape[2]), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment) + } + + values['epilogue_functor'] = operation.epilogue_functor.emit() + return SubstituteTemplate(self.gemm_template, values) + + ################################################################################################### # Runtime module for GEMM Grouped ################################################################################################### - class GemmRTGrouped(GemmRTbase): """ GemmRTGrouped manages the CUTLASS runtime components @@ -713,7 +1010,7 @@ class GemmRTGrouped(GemmRTbase): def __init__(self, operation: 'GemmOperation'): super(GemmRTGrouped, self).__init__(operation) - self.extra_funcs = ['precompute'] + self.extra_funcs = {'precompute': None} self.emitter = EmitGemmGroupedInstance('_type') self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) @@ -761,7 +1058,7 @@ def __init__( self, gemm_kind, arch, tile_description: TileDescription, A: TensorDescription, B: TensorDescription, C: TensorDescription, epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + swizzling_functor=cutlass.IdentitySwizzle1, api=False, **kwargs): #: operation kind self.operation_kind: OperationKind = OperationKind.Gemm @@ -772,8 +1069,11 @@ def __init__( #: gemm kind self.gemm_kind: GemmKind = gemm_kind + self.api = api + self.prefix = "3x" if self.api == ApiVersion.v3x else "" + # use deep copy to avoid overwritting the original TensorDescription - if C.layout == cutlass.ColumnMajor: + if self.api != ApiVersion.v3x and C.layout == cutlass.ColumnMajor: #: Operand A self.A: TensorDescription = copy.deepcopy(B) #: Operand B @@ -800,7 +1100,6 @@ def __init__( self.direct_store = kwargs["direct_store"] else: self.direct_store = False - if "visitor" in kwargs: self.visitor = kwargs["visitor"] else: @@ -872,8 +1171,11 @@ def core_name(self): math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys( ) else '' - inst_shape = "%d%d%d" % tuple( - self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.instruction_shape is not None: + inst_shape = "%dx%dx%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + else: + inst_shape = "Default" inst_shape += math_op_string if self.tile_description.math_instruction.element_a != self.A.element and \ @@ -905,6 +1207,17 @@ def extended_name(self): return extended_name + # + def extended_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_c = DataTypeNames[self.C.element], + core_name = self.core_name()) + return extended_name + # def layout_name(self): if self.is_complex() or self.is_planar_complex(): @@ -916,25 +1229,49 @@ def layout_name(self): ) return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + # def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment': "%d" % self.A.alignment, - } - ) + if self.api == ApiVersion.v3x and self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}" + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name_3x(), + tbm = self.tile_description.threadblock_shape[0], + tbn = self.tile_description.threadblock_shape[1], + tbk = self.tile_description.threadblock_shape[2], + cm = self.tile_description.cluster_shape[0], + cn = self.tile_description.cluster_shape[1], + ck = self.tile_description.cluster_shape[2], + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(self.A.alignment)) + else: + threadblock = self.tile_description.procedural_name() + return "cutlass{p}_sm{ar}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name(), + tb = threadblock, + l = self.layout_name(), + a = str(self.A.alignment)) # def configuration_name(self): @@ -945,9 +1282,14 @@ def configuration_name(self): class GemmOperationUniversal(GemmOperationBase): def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + api = api_version(arch, tile_description.math_instruction.opcode_class, A.element) super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, - A, B, C, epilogue_functor, swizzling_functor, **kwargs) - self.rt_module = GemmRTUniversal(self) + A, B, C, epilogue_functor, swizzling_functor, + api=api, **kwargs) + if api == ApiVersion.v3x: + self.rt_module = GemmRTUniversal3x(self) + else: + self.rt_module = GemmRTUniversal(self) self.argument_type = self.rt_module.argument_type self.epilogue_type = self.rt_module.epilogue_type diff --git a/tools/library/scripts/pycutlass/src/pycutlass/library.py b/tools/library/scripts/pycutlass/src/pycutlass/library.py index 0854624764..b18f2be278 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/library.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/library.py @@ -36,6 +36,7 @@ import enum import cutlass +import cute # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. @@ -182,6 +183,30 @@ class GeneratorTarget(enum.Enum): cutlass.dtype.cs64: 128, } + +class DataTypeSizeBytes: + """ + Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the + data type key is less than a full byte or a non-integer number of bytes. + """ + @staticmethod + def __class_getitem__(datatype): + """ + Returns the number of bytes in size the data type is. Raises an exception if the data type + is either less than a full byte or a non-integer number of bytes in size. + + :param datatype: data type to query + + :return: number of bytes the data type occupies + :rtype: int + """ + bits = DataTypeSize[datatype] + if bits < 8: + raise Exception('Data type {} is less than one byte in size.'.format(datatype)) + elif bits % 8 != 0: + raise Exception('Data type {} is not an integer number of bytes.'.format(datatype)) + return bits // 8 + ################################################################################################### # @@ -350,6 +375,12 @@ class MathOperation(enum.Enum): (cutlass.RowMajor, cutlass.complex_transform.conj): 'h' } +# +CuTeLayoutTag = { + cute.GMMAMajor.K: 'cute::GMMA::Major::K', + cute.GMMAMajor.MN: 'cute::GMMA::Major::MN' +} + ################################################################################################### # @@ -436,7 +467,6 @@ class DiagType(enum.Enum): # - class OperationKind(enum.Enum): Gemm = enum_auto() RankK = enum_auto() @@ -460,16 +490,19 @@ class OperationKind(enum.Enum): 70: 'volta', 75: 'turing', 80: 'ampere', + 90: 'hopper' } # SharedMemPerCC = { - 70: 96, # 96KB of SMEM - 72: 96, # 96KB of SMEM - 75: 64, # 64KB of SMEM - 80: 160, # 164KB of SMEM - 4KB reserved for the driver - 86: 100, # 100KB of SMEM - 87: 160, # 164KB of SMEM - 4KB reserved for the driver + 70: 96 << 10, # 96KB of SMEM + 72: 96 << 10, # 96KB of SMEM + 75: 64 << 10, # 64KB of SMEM + 80: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver + 86: 100 << 10, # 100KB of SMEM + 87: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver + 89: 100 << 10, # 100KB of SMEM + 90: 227 << 10, # 228KB of SMEM - 1KB reserved for the driver } ################################################################################################### @@ -646,7 +679,21 @@ class ConvMode(enum.Enum): class MathInstruction: + """ + Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel + """ def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class=cutlass.OpClass.Simt, math_operation=MathOperation.multiply_add): + """ + :param instruction_shape: size of the [M, N, K] dimensions of the instruction + :type instruction_shape: list or tuple + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_accumulator: data type used in accumulation + :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core) + :type opcode_class: cutlass.OpClass + :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate) + :type math_operation: MathOperation + """ self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b @@ -658,24 +705,65 @@ def __init__(self, instruction_shape, element_a, element_b, element_accumulator, class TileDescription: - - def __init__(self, threadblock_shape, stages, warp_count, math_instruction): + """ + Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, + stage count, and math instruction specification + """ + def __init__(self, threadblock_shape, stages, warp_count, math_instruction, cluster_shape=[1, 1, 1], persistent=False): + """ + :param threadblock_shape: shape of a threadblock tyle + :type threadblock_shape: list or tuple + :param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum + number of stages that can be supported for an operation on a given architecture will be computed at a later time + :type stages: int or None + :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile + :type warp_count: list, tuple, or None + :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands + :type math_instruction: MathInstruction + :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster + :param persistent: whether the kernel uses persistent warp-specialized threadblocks (only available for SM90+) + :type persistent: bool + """ self.threadblock_shape = threadblock_shape - - #: number of pipeline stages + self.cluster_shape = cluster_shape + self.persistent: bool = persistent self.stages: int = stages - #: number of warps along x, y, z directions - self.warp_count: list[int] = warp_count self.math_instruction = math_instruction - #: number threads per threadblock - self.num_threads: int = 32 - for cnt in self.warp_count: - self.num_threads *= cnt + # Number of warps along x, y, z directions + self.warp_count = warp_count + + @property + def num_threads(self): + """ + Returns the number of threads in the threadblock + + :return: number of threads in the threadblock + :rtype: int or None (if warp count is None) + """ + if self.warp_count is not None: + threads = 32 + for cnt in self.warp_count: + threads *= cnt + return threads + return None def procedural_name(self): - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + """ + Returns a name identifying the tile description + + :return: name identifying the tile description + :rtype: int + """ + emit_stages = 0 if self.stages is None else self.stages + name = "%dx%dx%d_%dx%d_%dx%d" % ( + self.cluster_shape[0], self.cluster_shape[1], self.cluster_shape[2], + self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], emit_stages) + + if self.persistent: + name += '_persistent' + return name # @@ -715,30 +803,68 @@ def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment=1 ################################################################################################### # +def CalculateSmemUsagePerStage(operation): + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: pycutlass.Operation + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = operation.tile_description.threadblock_shape + if operation.operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return (DataTypeSize[operation.A.element] * m * k // 8) + \ + (DataTypeSize[operation.B.element] * k * n // 8) + stage_barrier_bytes + else: + raise Exception('Unsupported operation kind {}.'.format(operation.operation_kind)) + + +# def CalculateSmemUsage(operation): - cta_shape = operation.tile_description.threadblock_shape - stages = operation.tile_description.stages - - if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: - # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) - if DataTypeSize[operation.A.element] == 32: - elements_per_8b_md = 2 - elif DataTypeSize[operation.A.element] == 4: - elements_per_8b_md = 8 - else: - elements_per_8b_md = 4 - - smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ - DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ - cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: pycutlass.Operation + + :return: int + """ + return operation.tile_description.stages * CalculateSmemUsagePerStage(operation) + + +class ApiVersion(enum.Enum): + """ + Differentiate between CUTLASS 2.x and 3.x API versions + """ + v2x = enum_auto() + v3x = enum_auto() + + +def api_version(arch, opclass, datatype): + """ + Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x + or 3.x for code emission. + + :param arch: compute capability of device on which to run + :type arch: int + :param opclass: class of the operation being performed + :type opclass: cutlass.OpClass + :param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same) + + :return: API version to be used in code emission + :rtype: ApiVersion + """ + if arch >= 90 and opclass == cutlass.OpClass.TensorOp and (datatype != cutlass.float64): + return ApiVersion.v3x else: - # Few BLAS3 operations only have A tensor - smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \ - DataTypeSize[operation.A.element] * \ - cta_shape[1] * cta_shape[2] // 8 + return ApiVersion.v2x - smem_usage = smem_per_stage * stages - return (smem_usage >> 10) ################################################################################################### diff --git a/tools/library/scripts/pycutlass/src/pycutlass/operation.py b/tools/library/scripts/pycutlass/src/pycutlass/operation.py index e35952099e..9184e514c9 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/operation.py @@ -32,6 +32,12 @@ import ctypes from cuda import cuda +from pycutlass.utils.device import device_cc + +from cuda import __version__ as __cuda_version__ +_version_splits = [int(x) for x in __cuda_version__.split('.')] +supports_cluster_launch = device_cc() >= 90 and (_version_splits[0] > 11 or (_version_splits[0] == 11 and _version_splits[1] >= 8)) + ################################################################################ # @@ -90,21 +96,58 @@ def plan(self, arguments): def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)): raise NotImplementedError() + # - def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)): + def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)): + if hasattr(self.operation, 'tile_description') and hasattr(self.operation.tile_description, 'cluster_shape'): + attr = cuda.CUlaunchAttribute() + attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape + attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attrs = [attr] + + # Allow for non-portable cluster sizes + err, = cuda.cuFuncSetAttribute( + self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1) + if err != cuda.CUresult.CUDA_SUCCESS: + return err + else: + attrs = [] + + config = cuda.CUlaunchConfig() + config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid + config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block + config.blockDimZ = launch_config.block[2] + config.sharedMemBytes = launch_config.shared_memory_capacity + config.hStream = stream + config.attrs = attrs + config.numAttrs = len(attrs) + + err, = cuda.cuLaunchKernelEx(config, f=self.kernel, kernelParams=kernel_params, extra=0) + return err - cArg = (ctypes.c_char * len(host_workspace) - ).from_buffer(host_workspace) - packed = (ctypes.c_void_p * 1)() - packed[0] = ctypes.addressof(cArg) + # + def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)): err, = cuda.cuLaunchKernel( self.kernel, launch_config.grid[0], launch_config.grid[1], launch_config.grid[2], launch_config.block[0], launch_config.block[1], launch_config.block[2], launch_config.shared_memory_capacity, stream, - packed, + kernel_params, 0) return err + + + # + def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)): + cArg = (ctypes.c_char * len(host_workspace) + ).from_buffer(host_workspace) + packed = (ctypes.c_void_p * 1)() + packed[0] = ctypes.addressof(cArg) + + if supports_cluster_launch: + return self.run_with_clusters(launch_config, packed, stream) + else: + return self.run_without_clusters(launch_config, packed, stream) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/parser.py b/tools/library/scripts/pycutlass/src/pycutlass/parser.py index 551638c46d..6eb02bfbfb 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/parser.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/parser.py @@ -543,7 +543,6 @@ def __init__(self, elementwise_functor, tile_description, self.elements_per_access = elements_per_access self.element_compute = element_compute self.element_output = element_output - # TODO: deprecate this self.elementwise_functor = elementwise_functor pass @@ -554,11 +553,8 @@ def initialize(self): # tree = function.epilogue_tree self.tree = tree - # self.tree.show() # for debug function.pass_binary_2_unary(self.tree, self.tree.root) - # self.tree.show() # for debug function.pass_inject_reduction(self.tree, self.tree.root) - # self.tree.show() # for debug function.pass_inject_epilogue_op(self.tree,self.tree.root) visitor = self.tree.get_node(self.tree.root).data.epilogue_node @@ -575,7 +571,6 @@ def __init__(self, **kwargs) -> None: if input_key == "accum": continue if function.input_args[input_key][0] == "scalar": - # _kwargs[input_key] = kwargs[input_key] continue # tensor input else: diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py index 8192b0808e..63ae6da94e 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py @@ -265,15 +265,6 @@ def flops(self, problem_size): flops_total_ = flops_mainloop_ + flops_epilogue_ - # TODO complex-value support - # switch (operation_desc.tile_description.math_instruction.math_operation) { - # case library::MathOperationID::kMultiplyAddComplex: - # flops_total_ *=4; - # break; - - # default: break; - # } - return flops_total_ @@ -511,9 +502,8 @@ def run(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, # (conv_blacklist_sizes) ############################################################################################################ -def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleaved=False): # TODO: conv_test_sizes and conv_blacklist_sizes +def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleaved=False): passed = True - # # Testbed object # @@ -529,8 +519,6 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleave # Vector of conv2d problem sizes to avoid duplicate runs conv_tested_sizes = [] - # TODO: include resnet 50 sizes, user sepecified sizes, and rigorous sizes - # Flatten 2D problem_vectors into a 1D problem sizes problem_sizes = conv_problems.conv2d_default_sizes @@ -539,7 +527,6 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleave # Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0) for conv_problem in problem_sizes: - # TODO: skip blacklist problem sizes if conv_problem in conv_tested_sizes: continue @@ -585,9 +572,8 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleave passed = testbed.run(conv_problem) - # if not passed: return False - - # TODO: If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts + if not passed: + return False if interleaved: return True diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py index f362395ece..6cf14f32a2 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py @@ -184,7 +184,7 @@ def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool arguments.sync() # - # Reference check - TODO: support caching results + # Reference check # alpha = self.compute_type(alpha).value() beta = self.compute_type(beta).value() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py index bfdaf0c2be..4fb46c1fa2 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py @@ -33,6 +33,7 @@ from time import sleep import pycutlass from pycutlass import * +import pycutlass.utils.datatypes as datatypes import cutlass from cuda import cudart from cuda import cuda @@ -52,16 +53,22 @@ def transpose(layout): return cutlass.ColumnMajorInterleaved32 -def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout): +def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout, batch_offset: int = 0): ptr = tensor.__array_interface__['data'][0] if operand == "a": tensor_coord = problem_size.mk() + batch_stride = problem_size.m() * problem_size.k() elif operand == "b": tensor_coord = problem_size.kn() + batch_stride = problem_size.k() * problem_size.n() elif operand in ["c", "d"]: tensor_coord = problem_size.mn() + batch_stride = problem_size.m() * problem_size.n() else: - raise ValueError("unknonw operand: " + operand) + raise ValueError("Unknown operand: " + operand) + + elt_size = DataTypeSizeBytes[datatypes.to_cutlass(tensor.dtype)] + ptr += batch_offset * batch_stride * elt_size if layout == cutlass.RowMajor: layout = cutlass.RowMajor.packed(tensor_coord) @@ -96,8 +103,8 @@ def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, opera return getattr(cutlass, ref_name)(ptr, layout) -def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: str): - tensor_ref = getTensorRef(tensor, problem_size, operand, layout) +def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: str, batch_offset: int = 0): + tensor_ref = getTensorRef(tensor, problem_size, operand, layout, batch_offset) if operand == "a": tensor_coord = problem_size.mk() @@ -106,7 +113,7 @@ def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, oper elif operand in ["c", "d"]: tensor_coord = problem_size.mn() else: - raise ValueError("unknonw operand: " + operand) + raise ValueError("Unknown operand: " + operand) if layout == cutlass.RowMajor: layout_tag = "RowMajor" @@ -168,7 +175,12 @@ def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interl # Compile the operator # - pycutlass.compiler.add_module([operation, self.reduction_operation]) + op_list = [operation] + if operation.arch < 90: + # Split K via Python is currently only supported for pre-SM90 kernels + op_list.append(self.reduction_operation) + + pycutlass.compiler.add_module(op_list) self.operation = operation @@ -206,8 +218,10 @@ def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interl def print_problem_size(self, p, mode, batch_count): if mode == cutlass.gemm.Mode.Gemm: mode = "Gemm" + elif mode == cutlass.gemm.Mode.Batched: + mode = "GemmBatched" elif mode == cutlass.gemm.Mode.GemmSplitKParallel: - mode = "GemmSplitKParalel" + mode = "GemmSplitKParallel" problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % ( p.m(), p.n(), p.k(), batch_count, mode) print(problem_size) @@ -251,8 +265,7 @@ def reorder_tensor_B(self, tensor_B, problem_size): tensor_ref_B, reordered_tensor_ref_B, problem_size) return reordered_tensor_B - def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): - # TODO + def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C, alpha, beta): tensor_D_ref = np.ones_like(tensor_C) alpha = self.numpy_type(self.compute_type)(alpha) beta = self.numpy_type(self.compute_type)(beta) @@ -262,42 +275,46 @@ def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta beta = self.compute_type(beta).value() init_acc = self.accumulator_type(init_acc).value() - if self.operation.switched: - tensor_ref_A = getTensorRef( - tensor_A, problem_size, "a", transpose(self.operation.B.layout)) - tensor_ref_B = getTensorRef( - tensor_B, problem_size, "b", transpose(self.operation.A.layout)) - tensor_ref_C = getTensorRef( - tensor_C, problem_size, "c", transpose(self.operation.C.layout)) - tensor_ref_D_ref = getTensorRef( - tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout)) - else: - tensor_ref_A = getTensorRef( - tensor_A, problem_size, "a", self.operation.A.layout) - tensor_ref_B = getTensorRef( - tensor_B, problem_size, "b", self.operation.B.layout) - tensor_ref_C = getTensorRef( - tensor_C, problem_size, "c", self.operation.C.layout) - tensor_ref_D_ref = getTensorRef( - tensor_D_ref, problem_size, "d", self.operation.C.layout) - - if self.math_operation in [MathOperation.multiply_add_saturate]: - cutlass.test.gemm.host.gemm_saturate( - problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) - else: - cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, - tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + for i in range(batch_count): + if self.operation.switched: + tensor_ref_A = getTensorRef( + tensor_A, problem_size, "a", transpose(self.operation.B.layout), batch_offset=i) + tensor_ref_B = getTensorRef( + tensor_B, problem_size, "b", transpose(self.operation.A.layout), batch_offset=i) + tensor_ref_C = getTensorRef( + tensor_C, problem_size, "c", transpose(self.operation.C.layout), batch_offset=i) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout), batch_offset=i) + else: + tensor_ref_A = getTensorRef( + tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i) + tensor_ref_B = getTensorRef( + tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i) + tensor_ref_C = getTensorRef( + tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i) + + if self.math_operation in [MathOperation.multiply_add_saturate]: + cutlass.test.gemm.host.gemm_saturate( + problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + else: + cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, + tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) return tensor_D_ref - def equal(self, tensor_D, tensor_D_ref, problem_size): + def equal(self, tensor_D, tensor_D_ref, problem_size, batch_count): + for i in range(batch_count): + tensor_view_D = getTensorView( + tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i) + tensor_view_D_ref = getTensorView( + tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i) - tensor_view_D = getTensorView( - tensor_D, problem_size, "d", self.operation.C.layout) - tensor_view_D_ref = getTensorView( - tensor_D_ref, problem_size, "d", self.operation.C.layout) + if not cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref): + return False - return cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref) + return True def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0): m = problem_size.m() @@ -321,9 +338,8 @@ def flops(self, problem_size, batch_count=1): n = problem_size.n() k = problem_size.k() - flops_ = (m * n * k + m * n) * 2 * batch_count + flops_ = (m * n * k) * 2 * batch_count - # TODO: complex return flops_ def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): @@ -368,21 +384,25 @@ def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, bet return runtime - def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): - + def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): assert get_allocated_size( ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() np.random.seed(self.seed) + # Assign an actual batch count in cases where we are not running in batched mode. + # This is to differentiate between the number of split K slices and the batch count, + # which are overloaded within the single `batch_count` variable. + true_batch_count = batch_count if mode == cutlass.gemm.Mode.Batched else 1 + tensor_A = self.uniform_init( - size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A) + size=(problem_size.m() * problem_size.k() * true_batch_count,), dtype=self.dtype_A) tensor_B = self.uniform_init( - size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B) + size=(problem_size.n() * problem_size.k() * true_batch_count,), dtype=self.dtype_B) tensor_C = self.uniform_init( - size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C) + size=(problem_size.m() * problem_size.n() * true_batch_count,), dtype=self.dtype_C) tensor_D = np.zeros( - shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D) + shape=(problem_size.m() * problem_size.n() * true_batch_count,), dtype=self.dtype_D) # # Launch kernel @@ -392,14 +412,14 @@ def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): operation=self.operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, output_op=self.operation.epilogue_type(alpha, beta), - gemm_mode=mode, split_k_slices=batch_count + gemm_mode=mode, split_k_slices=split_k_slices, batch=batch_count ) if mode == cutlass.gemm.Mode.GemmSplitKParallel: reduction_arguments = ReductionArguments( self.reduction_operation, problem_size=[ problem_size.m(), problem_size.n()], - partitions=batch_count, + partitions=split_k_slices, workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, @@ -419,8 +439,8 @@ def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): else: arguments.sync() tensor_D_ref = self.host_reference( - problem_size, tensor_A, tensor_B, tensor_C, alpha, beta) - passed = self.equal(tensor_D, tensor_D_ref, problem_size) + problem_size, true_batch_count, tensor_A, tensor_B, tensor_C, alpha, beta) + passed = self.equal(tensor_D, tensor_D_ref, problem_size, true_batch_count) try: assert passed @@ -494,7 +514,7 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): if operation.A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: interleavedk = 32 else: - raise ValueError("unknonw layout") + raise ValueError("Unknown layout") if testcase == "interleaved": modes = [cutlass.gemm.Mode.Gemm, ] @@ -515,14 +535,22 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): problem_beta = [0.0] batch_counts = [1, ] else: # universal - modes = [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel] + modes = [cutlass.gemm.Mode.Gemm] + batch_counts = [1, 2, 3, 5, 7] + if operation.arch < 90: + # Split K kernels via Python are currently only supported pre-SM90 + modes.append(cutlass.gemm.Mode.GemmSplitKParallel) + problem_size_m = [alignment_m, 512 - 3 * alignment_m] problem_size_n = [alignment_n, 512 - 2 * alignment_n] + if operation.tile_description.stages is None: + stages_for_k_calc = 7 + else: + stages_for_k_calc = operation.tile_description.stages problem_size_k = [ alignment_k, - threadblock_k * operation.tile_description.stages - alignment_k, - threadblock_k * operation.tile_description.stages * 3 - alignment_k] - batch_counts = [1, 2, 3, 5, 7] + threadblock_k * stages_for_k_calc - alignment_k, + threadblock_k * stages_for_k_calc * 3 - alignment_k] problem_alpha = [1.0] problem_beta = [2.0] @@ -543,8 +571,17 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): problem_size = cutlass.gemm.GemmCoord(m, n, k) + if operation.arch < 90: + split_k_slices = batch_count + else: + split_k_slices = 1 + + overridden_mode = mode + if mode == cutlass.gemm.Mode.Gemm and batch_count > 1: + overridden_mode = cutlass.gemm.Mode.Batched + passed = testbed.run( - mode, problem_size, batch_count, alpha, beta) + overridden_mode, problem_size, batch_count, split_k_slices, alpha, beta) err, = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/utils.py b/tools/library/scripts/pycutlass/src/pycutlass/test/utils.py new file mode 100644 index 0000000000..55281bec6d --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/utils.py @@ -0,0 +1,109 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import cutlass +from pycutlass import library, SubstituteTemplate + + +class Layout: + """ + Utility class to map transpose and non-transpose terminology to row- and column-major terminology + """ + T = cutlass.RowMajor + N = cutlass.ColumnMajor + + +class LayoutCombination: + """ + Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs + """ + NNN = (Layout.N, Layout.N, Layout.N) + NNT = (Layout.N, Layout.N, Layout.T) + NTN = (Layout.N, Layout.T, Layout.N) + NTT = (Layout.N, Layout.T, Layout.T) + TNN = (Layout.T, Layout.N, Layout.N) + TNT = (Layout.T, Layout.N, Layout.T) + TTN = (Layout.T, Layout.T, Layout.N) + TTT = (Layout.T, Layout.T, Layout.T) + + +def get_name(layouts, alignments, element_output, + element_accumulator, element_epilogue, cluster_shape, + threadblock_shape, stages, element_a, element_b, arch, opclass, suffix=""): + """ + Generates a procedural name for a test case. + + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alingments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param element_a: data type of operand A + :param element_b: data type of operand B + :param arch: compute capability of kernel being generated + :type arch: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param suffix: additional string to add to the suffix of the name + :type suffix: str + + :return: str + """ + name_format = 'test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${suffix}' + return SubstituteTemplate(name_format, + { + 'arch': str(arch), + 'eA': library.DataTypeNames[element_a], + 'eB': library.DataTypeNames[element_b], + 'eC': library.DataTypeNames[element_output], + 'lA': library.ShortLayoutTypeNames[layouts[0]], + 'lB': library.ShortLayoutTypeNames[layouts[1]], + 'lC': library.ShortLayoutTypeNames[layouts[2]], + 'opclass': library.OpcodeClassNames[opclass], + 'acc': library.DataTypeNames[element_accumulator], + 'cM': str(cluster_shape[0]), + 'cN': str(cluster_shape[1]), + 'cK': str(cluster_shape[2]), + 'tbM': str(threadblock_shape[0]), + 'tbN': str(threadblock_shape[1]), + 'tbK': str(threadblock_shape[2]), + 'stages': str(stages) if stages is not None else 'auto', + 'aA' : str(alignments[0]), + 'aB' : str(alignments[1]), + 'aC' : str(alignments[2]), + 'suffix': '' if suffix is None else suffix + } + ) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py new file mode 100644 index 0000000000..f4cc56baf7 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py @@ -0,0 +1,121 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for converting between frontend datatypes and CUTLASS datatypes +""" + +from typing import Union, Tuple + +import cutlass + +import pycutlass.library as library + + +try: + import numpy as np + numpy_available = True +except ImportError: + numpy_available = False + +def numpy_to_cutlass(inp): + if numpy_available: + if inp == np.float16: + return cutlass.float16 + elif inp == np.float32: + return cutlass.float32 + elif inp == np.float64: + return cutlass.float64 + elif inp == np.int8: + return cutlass.int8 + elif inp == np.int32: + return cutlass.int32 + return None + +try: + import cupy as cp + cupy_available = True + cupy_to_cutlass_dict = { + cp.float16: cutlass.float16, + cp.float32: cutlass.float32, + cp.float64: cutlass.float64 + } +except ImportError: + cupy_available = False + +def cupy_to_cutlass(inp): + if cupy_available: + if inp == cp.float16: + return cutlass.float16 + elif inp == cp.float32: + return cutlass.float32 + elif inp == cp.float64: + return cutlass.float64 + return None + +try: + import torch + torch_available = True + torch_to_cutlass_dict = { + torch.half: cutlass.float16, + torch.float16: cutlass.float16, + torch.float: cutlass.float32, + torch.float32: cutlass.float32, + torch.double: cutlass.float64, + torch.float64: cutlass.float64 + } +except ImportError: + torch_available = False + +def torch_to_cutlass(inp): + if torch_available: + return torch_to_cutlass_dict.get(inp, None) + +try: + import bfloat16 + bfloat16_available = True +except ImportError: + bfloat16_available = False + +def bfloat16_to_cutlass(inp): + if bfloat16_available: + if inp == bfloat16.bfloat16: + return cutlass.bfloat16 + + +def to_cutlass(inp): + for cvt_fn in [bfloat16_to_cutlass, cupy_to_cutlass, numpy_to_cutlass, torch_to_cutlass]: + out = cvt_fn(inp) + if out is not None: + return out + + raise Exception('No available conversion from type {} to a CUTLASS type.'.format(inp)) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 6948d27487..2f003b5046 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu from pycutlass.conv2d_operation import * from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 26741ced8b..2813f1c79e 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 821f99c7c6..93d9e3bbef 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu import pycutlass from pycutlass.conv2d_operation import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 210c2ba34f..53fb0ebc64 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index 54dbea9646..0e9806167a 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass.test import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index 4be81f99b3..b4d9b45e7b 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass.test import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 49d59c1a53..cf772782f3 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 36d115e43d..8276bdd966 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 578b5fd862..6949697f21 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu import pycutlass from pycutlass.conv2d_operation import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index aa9f1da6e8..10520e1f32 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 5e4ce635d8..efa2d2d1fe 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 64b40dd713..2e6828c2da 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 96f9ff36be..bb7533b601 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index a42a098086..e2a60f9a55 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu import pycutlass from pycutlass.conv2d_operation import * diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index b64bd39f57..213618b112 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass import * diff --git a/tools/library/scripts/pycutlass/test/conv/run_all_tests.py b/tools/library/scripts/pycutlass/test/conv/run_all_tests.py index 39278be21e..9fec5d28eb 100644 --- a/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +++ b/tools/library/scripts/pycutlass/test/conv/run_all_tests.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass import unittest from pycutlass.memory_manager import * diff --git a/tools/library/scripts/pycutlass/test/example/run_all_example.sh b/tools/library/scripts/pycutlass/test/example/run_all_example.sh index c05eb048db..0a51ccf677 100755 --- a/tools/library/scripts/pycutlass/test/example/run_all_example.sh +++ b/tools/library/scripts/pycutlass/test/example/run_all_example.sh @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 diff --git a/tools/library/scripts/pycutlass/test/frontend/run_test.sh b/tools/library/scripts/pycutlass/test/frontend/run_test.sh index 67aa3de57d..072f60b5c1 100644 --- a/tools/library/scripts/pycutlass/test/frontend/run_test.sh +++ b/tools/library/scripts/pycutlass/test/frontend/run_test.sh @@ -1 +1,33 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + CUPY_CACHE_DIR=./ python test_frontend.py diff --git a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py index ca6760cc97..8eaf42f6b8 100644 --- a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +++ b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py @@ -29,13 +29,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# -## Test case for Pytorch + +""" +Test cases for frontends +""" + import pycutlass import unittest from pycutlass import * from pycutlass.utils.device import device_cc -import torch -import cupy as cp class Test_Frontend(unittest.TestCase): @@ -49,9 +51,7 @@ def setUp(self) -> None: cutlass.OpClass.Simt, MathOperation.multiply_add ) - # Stages > 2 is supported only for compute capability 80 and beyond - stages = 4 if cc >= 80 else 2 - + stages = 2 tile_description = TileDescription( [128, 128, 8], stages, [2, 4, 1], math_inst @@ -84,6 +84,11 @@ def setUp(self) -> None: def test_torch_frontend(self): + try: + import torch + except: + self.assertTrue(False, "Unable to import torch") + problem_size = cutlass.gemm.GemmCoord(512, 256, 128) tensor_A = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.k()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) @@ -111,6 +116,11 @@ def test_torch_frontend(self): self.assertTrue(torch.equal(tensor_D, tensor_D_ref)) def test_cupy_frontend(self): + try: + import cupy as cp + except: + self.assertTrue(False, "Unable to import cupy") + cp.cuda.set_allocator(rmm.rmm_cupy_allocator) problem_size = cutlass.gemm.GemmCoord(512, 256, 128) @@ -139,7 +149,6 @@ def test_cupy_frontend(self): self.assertTrue(cp.array_equal(tensor_D, tensor_D_ref)) - if __name__ == '__main__': pycutlass.get_memory_pool(2**32, 2**32) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py index b03e2431bb..de81e4b0b6 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.test import * @@ -92,5 +124,5 @@ def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(se self.assertTrue(test_all_gemm(operation, "multistage")) if __name__ == '__main__': - pycutlass.get_memory_pool(2**24, 2**24) + pycutlass.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py new file mode 100644 index 0000000000..9237326aa3 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py @@ -0,0 +1,138 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from functools import partial +import pycutlass +from pycutlass import * +from pycutlass import library +from pycutlass.test import * +import unittest + +from pycutlass.test.utils import LayoutCombination, get_name +from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + + +name_fn = partial(get_name, element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, arch=90) + +def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, + cluster_shape, threadblock_shape, stages, opclass, persistent=False): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alingments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param persistent: whether this is a persistent warp-specialized kernel + :type persistent: bool + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.bfloat16 + element_B = cutlass.bfloat16 + inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + math_inst = MathInstruction( + instruction_shape=inst_shape, + element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, + opcode_class=opclass, math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=threadblock_shape, + cluster_shape=cluster_shape, + stages=stages, warp_count=warp_count, + math_instruction=math_inst, + persistent=persistent + ) + + A = TensorDescription(element=element_A, layout=layouts[0], alignment=alignments[0]) + B = TensorDescription(element=element_B, layout=layouts[1], alignment=alignments[1]) + C = TensorDescription(element=element_output, layout=layouts[2], alignment=alignments[2]) + + epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=90, tile_description=tile_description, A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor) + + self.assertTrue(test_all_gemm(operation, "universal")) + + if persistent: + suffix = "_persistent" + else: + suffix = "" + + name = name_fn(layouts, alignments, element_output, element_accumulator, + element_epilogue, cluster_shape, threadblock_shape, stages, opclass=opclass, suffix=suffix) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < 90, "Device compute capability is insufficient for SM90 tests.") +class GemmBF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) + +add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [4, 4, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 5) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 32], None, persistent=True) +add_test_simt(GemmBF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 8], 2) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**30, 2**30) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py index 6ffb04a5ad..b4f245e36e 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.test import * @@ -443,5 +475,5 @@ def test_SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_128x256x64_64x64x64(self) if __name__ == '__main__': - pycutlass.get_memory_pool(2**24, 2**24) + pycutlass.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py new file mode 100644 index 0000000000..81540b35ce --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py @@ -0,0 +1,182 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from functools import partial +import pycutlass +from pycutlass import * +from pycutlass import library +from pycutlass.test import * +import unittest + +from pycutlass.test.utils import LayoutCombination, get_name +from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + + +# Partial specialziation for naming tests +name_fn = partial(get_name, element_a=cutlass.float16, element_b=cutlass.float16, arch=90) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, + cluster_shape, threadblock_shape, stages, opclass, persistent=False): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alingments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param persistent: whether this is a persistent warp-specialized kernel + :type persistent: bool + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + + element_A = cutlass.float16 + element_B = cutlass.float16 + inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + math_inst = MathInstruction( + instruction_shape=inst_shape, + element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, + opcode_class=opclass, math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=threadblock_shape, + cluster_shape=cluster_shape, + stages=stages, warp_count=warp_count, + math_instruction=math_inst, + persistent=persistent + ) + + A = TensorDescription(element=element_A, layout=layouts[0], alignment=alignments[0]) + B = TensorDescription(element=element_B, layout=layouts[1], alignment=alignments[1]) + C = TensorDescription(element=element_output, layout=layouts[2], alignment=alignments[2]) + + epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=90, tile_description=tile_description, A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor) + + self.assertTrue(test_all_gemm(operation, "universal")) + + if persistent: + suffix = "_persistent" + else: + suffix = "" + + name = name_fn(layouts, alignments, element_output, element_accumulator, + element_epilogue, cluster_shape, threadblock_shape, stages, opclass=opclass, suffix=suffix) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < 90, "Device compute capability is insufficient for SM90 tests.") +class GemmF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) + +# Tests with 1x1x1 clusters +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 64, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 64], 5) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [2, 2, 2], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) + +# Tests with different cluster shapes +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 1, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 2, 1], [64, 128, 64], None) + +# Tests for persistent warp-specialized threadblocks +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 2, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 4, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 4, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 4, 1], [64, 128, 64], None, persistent=True) + +# Tests using SIMT +add_test_simt(GemmF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NTN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TTN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 8], 2) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**30, 2**30) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py index ad48d0ddec..0bdf008466 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.memory_manager import get_allocated_size diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py index 11d2668365..4e1aff7086 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.test import * @@ -98,5 +130,5 @@ def test_SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64_64x64x16_32x32x16(self): self.assertTrue(test_all_gemm(operation, "universal")) if __name__ == '__main__': - pycutlass.get_memory_pool(2**24, 2**24) + pycutlass.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py new file mode 100644 index 0000000000..4140ed4a1f --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py @@ -0,0 +1,124 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from functools import partial +import pycutlass +from pycutlass import * +from pycutlass import library +from pycutlass.test import * +import unittest + +from pycutlass.test.utils import LayoutCombination, get_name +from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + + +name_fn = partial(get_name, element_a=cutlass.float64, element_b=cutlass.float64, arch=90) + +def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, + cluster_shape, threadblock_shape, stages, opclass): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alingments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.float64 + element_B = cutlass.float64 + inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + math_inst = MathInstruction( + instruction_shape=inst_shape, + element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, + opcode_class=opclass, math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=threadblock_shape, + cluster_shape=cluster_shape, + stages=stages, warp_count=warp_count, + math_instruction=math_inst + ) + + A = TensorDescription(element=element_A, layout=layouts[0], alignment=alignments[0]) + B = TensorDescription(element=element_B, layout=layouts[1], alignment=alignments[1]) + C = TensorDescription(element=element_output, layout=layouts[2], alignment=alignments[2]) + + epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=90, tile_description=tile_description, A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor) + + self.assertTrue(test_all_gemm(operation, "universal")) + + name = name_fn(layouts, alignments, element_output, element_accumulator, + element_epilogue, cluster_shape, threadblock_shape, stages, opclass=opclass) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < 90, "Device compute capability is insufficient for SM90 tests.") +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) +add_test_simt(GemmF64Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.float64, cutlass.float64, cutlass.float64, [1, 1, 1], [64, 64, 32], 2) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**30, 2**30) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py index c7acd74296..a1ee9ed36d 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.test import * @@ -199,5 +231,5 @@ def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x3 if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + pycutlass.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py index 7ddeebbca7..552b3becdb 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass from pycutlass import * from pycutlass.epilogue import LinearCombinationClamp @@ -225,5 +257,5 @@ def test_SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_128x128x128_64x64x128(self) if __name__ == '__main__': - pycutlass.get_memory_pool(2**24, 2**24) + pycutlass.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py new file mode 100644 index 0000000000..e06d538f82 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py @@ -0,0 +1,154 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from functools import partial +import pycutlass +from pycutlass import * +from pycutlass import library +from pycutlass.test import * +import unittest + +from pycutlass.test.utils import LayoutCombination, get_name +from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + + +name_fn = partial(get_name, element_a=cutlass.float16, element_b=cutlass.float16, arch=90) + +def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, + cluster_shape, threadblock_shape, stages, opclass, persistent=False): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alingments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param persistent: whether this is a persistent warp-specialized kernel + :type persistent: bool + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.int8 + element_B = cutlass.int8 + inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + math_inst = MathInstruction( + instruction_shape=inst_shape, + element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, + opcode_class=opclass, math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=threadblock_shape, + cluster_shape=cluster_shape, + stages=stages, warp_count=warp_count, + math_instruction=math_inst, + persistent=persistent + ) + + A = TensorDescription(element=element_A, layout=layouts[0], alignment=alignments[0]) + B = TensorDescription(element=element_B, layout=layouts[1], alignment=alignments[1]) + C = TensorDescription(element=element_output, layout=layouts[2], alignment=alignments[2]) + + if opclass == cutlass.OpClass.Simt: + epilogue_functor_cls = LinearCombinationClamp + else: + epilogue_functor_cls = LinearCombination + epilogue_functor = epilogue_functor_cls(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=90, tile_description=tile_description, A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor) + + self.assertTrue(test_all_gemm(operation, "universal")) + + if persistent: + suffix = "_persistent" + else: + suffix = "" + + name = name_fn(layouts, alignments, element_output, element_accumulator, + element_epilogue, cluster_shape, threadblock_shape, stages, opclass=opclass, suffix=suffix) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < 90, "Device compute capability is insufficient for SM90 tests.") +class GemmS8Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) + +# Tests with 1x1x1 clusters +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNN, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], 3) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 8], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [64, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 64, 32], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [4, 4, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) + +# Tests with different cluster shapes +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [2, 2, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 4, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [4, 4, 1], [128, 128, 128], None) + +# Tests with persistent warp-specialized threadblocks +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [2, 1, 1], [128, 128, 128], None, persistent=True) + +# Tests for SIMT +add_test_simt(GemmS8Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [64, 32, 8], 2) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**30, 2**30) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py b/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py index 8a87444675..38f040b19f 100644 --- a/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +++ b/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py @@ -1,8 +1,40 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import pycutlass import unittest if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + pycutlass.get_memory_pool(2**30, 2**30) loader = unittest.TestLoader() tests = loader.discover('./', 'gemm_*.py') testRunner = unittest.runner.TextTestRunner() diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp new file mode 100644 index 0000000000..895de5bece --- /dev/null +++ b/tools/library/src/gemm_operation_3x.hpp @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation3xBase : public Operation { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::CollectiveEpilogue::ElementCompute; + +private: + + GemmDescription description_; + +public: + + /// Constructor + GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = gemm_kind_; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + +public: + + /// Constructor + GemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + if (operator_args.hw_info.sm_count <= 0) { + operator_args.hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); + } + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, GemmUniversalArguments const *arguments) { + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename ThreadEpilogueOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta)); + operator_args.epilogue_params.thread_params = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { + typename ThreadEpilogueOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta)); + operator_args.epilogue_params.thread_params = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // TODO: type erase Arguments structure in 3.0 GEMM + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + operator_args.ptr_A = static_cast(arguments->A); + operator_args.ptr_B = static_cast(arguments->B); + operator_args.epilogue_params.ptr_C = static_cast(arguments->C); + operator_args.epilogue_params.ptr_D = static_cast(arguments->D); + + operator_args.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue_params.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue_params.dD = operator_args.epilogue_params.dC; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream); + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index ef7904a030..fdfe2516cd 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -622,6 +622,8 @@ Status Handle::gemm_universal( char host_workspace[kHostWorkspaceSize]; GemmUniversalArguments arguments{ + {M, N, K}, + batch_count, ptr_A, ptr_B, ptr_C, @@ -629,6 +631,10 @@ Status Handle::gemm_universal( alpha, beta, scalar_pointer_mode_, + lda, + ldb, + ldc, + ldd, batch_stride_A, batch_stride_B, batch_stride_C, diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 1692e95efe..f9f425ccbc 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -86,6 +86,7 @@ target_link_libraries( $<$:nvidia::cublas> $<$:nvidia::cudnn> cudart + cuda_driver ) install( diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 911aa9bd82..4b15fda5f3 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -537,6 +537,13 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); + // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels + gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_.arguments.batch_count = problem_.batch_count; + gemm_workspace_.arguments.lda = problem_.lda; + gemm_workspace_.arguments.ldb = problem_.ldb; + gemm_workspace_.arguments.ldc = problem_.ldc; + gemm_workspace_.arguments.ldd = problem_.ldc; gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index acf852861f..b2e8f9b746 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -75,6 +75,9 @@ OperationProfiler::OperationProfiler( {ArgumentTypeID::kInteger, {"cta_m", "threadblock-shape::m"}, "Threadblock shape in the M dimension"}, {ArgumentTypeID::kInteger, {"cta_n", "threadblock-shape::n"}, "Threadblock shape in the N dimension"}, {ArgumentTypeID::kInteger, {"cta_k", "threadblock-shape::k"}, "Threadblock shape in the K dimension"}, + {ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"}, + {ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"}, + {ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"}, {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"}, {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"}, {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"}, @@ -198,6 +201,24 @@ bool OperationProfiler::satisfies( } } + if (arg_as_int(int_value, "cluster_m", problem_space, problem)) { + if (int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) { + return false; + } + } + + if (arg_as_int(int_value, "cluster_n", problem_space, problem)) { + if (int64_t(op_desc.tile_description.cluster_shape.n()) != int_value) { + return false; + } + } + + if (arg_as_int(int_value, "cluster_k", problem_space, problem)) { + if (int64_t(op_desc.tile_description.cluster_shape.k()) != int_value) { + return false; + } + } + if (arg_as_int(int_value, "stages", problem_space, problem)) { if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) { return false; @@ -596,6 +617,9 @@ void OperationProfiler::initialize_result_( set_argument(result, "cta_m", problem_space, operation_desc.tile_description.threadblock_shape.m()); set_argument(result, "cta_n", problem_space, operation_desc.tile_description.threadblock_shape.n()); set_argument(result, "cta_k", problem_space, operation_desc.tile_description.threadblock_shape.k()); + set_argument(result, "cluster_m", problem_space, operation_desc.tile_description.cluster_shape.m()); + set_argument(result, "cluster_n", problem_space, operation_desc.tile_description.cluster_shape.n()); + set_argument(result, "cluster_k", problem_space, operation_desc.tile_description.cluster_shape.k()); set_argument(result, "stages", problem_space, operation_desc.tile_description.threadblock_stages); set_argument(result, "warps_m", problem_space, operation_desc.tile_description.warp_count.m()); set_argument(result, "warps_n", problem_space, operation_desc.tile_description.warp_count.n()); diff --git a/tools/util/include/cutlass/util/GPU_Clock.hpp b/tools/util/include/cutlass/util/GPU_Clock.hpp new file mode 100644 index 0000000000..5f2dd4bd14 --- /dev/null +++ b/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +struct GPU_Clock +{ + GPU_Clock() { + cudaEventCreate(&start_); + cudaEventCreate(&stop_); + cudaEventRecord(start_); + } + + ~GPU_Clock() { + cudaEventDestroy(start_); + cudaEventDestroy(stop_); + } + + void start() { + cudaEventRecord(start_); + } + + float milliseconds() { + cudaEventRecord(stop_); + cudaEventSynchronize(stop_); + float time; + cudaEventElapsedTime(&time, start_, stop_); + return time; + } + + float seconds() { + return milliseconds() * float(1e-3); + } + + private: + cudaEvent_t start_, stop_; +}; diff --git a/tools/util/include/cutlass/util/cublas_wrappers.hpp b/tools/util/include/cutlass/util/cublas_wrappers.hpp new file mode 100644 index 0000000000..82d56fa18f --- /dev/null +++ b/tools/util/include/cutlass/util/cublas_wrappers.hpp @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +//-- BLAM_DEBUG_OUT --------------------------------------------------------- +#ifdef BLAM_DEBUG +# include +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl +# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl +# endif // BLAM_DEBUG_OUT +#else +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) +# define BLAM_DEBUG_OUT_2(msg) +# endif // BLAM_DEBUG_OUT +#endif // BLAM_DEBUG + +// User could potentially define ComplexFloat/ComplexDouble instead of std:: +#ifndef BLAM_COMPLEX_TYPES +#define BLAM_COMPLEX_TYPES 1 +#include +namespace blam { +template +using Complex = cuda::std::complex; +using ComplexFloat = cuda::std::complex; +using ComplexDouble = cuda::std::complex; +} +#endif // BLAM_COMPLEX_TYPES + +// User could potentially define Half instead of cute:: +#ifndef BLAM_HALF_TYPE +#define BLAM_HALF_TYPE 1 +#include +namespace blam { +using Half = cute::half_t; +} +#endif // BLAM_HALF_TYPE + +namespace blam +{ +namespace cublas +{ + +inline const char* +cublas_get_error(cublasStatus_t status) +{ + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; + default: + return "CUBLAS_ERROR -- "; + } +} + +inline bool +cublas_is_error(cublasStatus_t status) +{ + return status != CUBLAS_STATUS_SUCCESS; +} + + +// hgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const Half* beta, + Half* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasHgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + reinterpret_cast(beta), + reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, + CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// mixed hf gemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + beta, + C, CUDA_R_32F, ldC, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// igemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const int32_t* alpha, + const int8_t* A, int ldA, + const int8_t* B, int ldB, + const int32_t* beta, + int32_t* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasIgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + A, CUDA_R_8I, ldA, + B, CUDA_R_8I, ldB, + beta, + C, CUDA_R_32I, ldC, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// sgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, + const float* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasSgemm"); + + return cublasSgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// dgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, + const double* B, int ldB, + const double* beta, + double* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasDgemm"); + + return cublasDgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// cgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, + const ComplexFloat* B, int ldB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasCgemm"); + + return cublasCgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// zgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, + const ComplexDouble* B, int ldB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasZgemm"); + + return cublasZgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, int loA, + const Half* B, int ldB, int loB, + const Half* beta, + Half* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); + + return cublasHgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast<__half*>(C), ldC, loC, + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, int loA, + const float* B, int ldB, int loB, + const float* beta, + float* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); + + return cublasSgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, int loA, + const double* B, int ldB, int loB, + const double* beta, + double* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); + + return cublasDgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, int loA, + const ComplexFloat* B, int ldB, int loB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); + + return cublasCgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, int loA, + const ComplexDouble* B, int ldB, int loB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); + + return cublasZgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* const A[], int ldA, + const Half* const B[], int ldB, + const Half* beta, + Half* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmBatched"); + + return cublasHgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), ldA, + // A, ldA, // cuBLAS 9.2 + reinterpret_cast(const_cast(B)), ldB, + // B, ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + reinterpret_cast<__half**>(const_cast(C)), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* const A[], int ldA, + const float* const B[], int ldB, + const float* beta, + float* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmBatched"); + + return cublasSgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* const A[], int ldA, + const double* const B[], int ldB, + const double* beta, + double* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmBatched"); + + return cublasDgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* const A[], int ldA, + const ComplexFloat* const B[], int ldB, + const ComplexFloat* beta, + ComplexFloat* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmBatched"); + + return cublasCgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* const A[], int ldA, + const ComplexDouble* const B[], int ldB, + const ComplexDouble* beta, + ComplexDouble* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmBatched"); + + return cublasZgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +} // end namespace cublas +} // end namespace blam diff --git a/tools/util/include/cutlass/util/device_layernorm.h b/tools/util/include/cutlass/util/device_layernorm.h index 6305dffd8e..c4ec9251bb 100644 --- a/tools/util/include/cutlass/util/device_layernorm.h +++ b/tools/util/include/cutlass/util/device_layernorm.h @@ -456,7 +456,7 @@ void layernorm(cutlass::MatrixCoord tensor_size, block.x = 1024; } // TODO : There should be better configs for different cases, we only use several samples to show how to use here - // TODO : using registers to store values locally can reduce the ldgs from global memory and speedup the kernels. + // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { block.x = (n/4 + 31)/32*32; if (std::is_same::value) { diff --git a/tools/util/include/cutlass/util/helper_cuda.hpp b/tools/util/include/cutlass/util/helper_cuda.hpp new file mode 100644 index 0000000000..15e0bc8540 --- /dev/null +++ b/tools/util/include/cutlass/util/helper_cuda.hpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +namespace cute +{ + +void +device_init(int device_id, bool quiet = false) +{ + cudaDeviceProp device_prop; + std::size_t device_free_physmem; + std::size_t device_total_physmem; + + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + + //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; + + if (!quiet) { + printf("Using device %d: %s (SM%d, %d SMs)\n", + device_id, device_prop.name, + device_prop.major * 10 + device_prop.minor, + device_prop.multiProcessorCount); + fflush(stdout); + } +} + +/** + * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. + */ +inline int +_ConvertSMVer2Cores(int major, int minor) +{ + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexidecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = { + {0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf("MapSMtoCores for SM %d.%d is undefined." + " Default to use %d Cores/SM\n", + major, minor, nGpuArchCoresPerSM[index - 1].Cores); + + return nGpuArchCoresPerSM[index - 1].Cores; +} + +} // end namespace cute diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp new file mode 100644 index 0000000000..7ecffaffa1 --- /dev/null +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing a rank-X shape into a rank-(X-1) stride in CuTe. +*/ + +#pragma once + +#include "cute/stride.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +cute::Stride, StrideIntT> +make_cute_packed_stride(cute::Stride, StrideIntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +cute::Stride, StrideIntT, int64_t> +make_cute_packed_stride(cute::Stride, StrideIntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/print_error.hpp b/tools/util/include/cutlass/util/print_error.hpp new file mode 100644 index 0000000000..f867f88e65 --- /dev/null +++ b/tools/util/include/cutlass/util/print_error.hpp @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +// The computed infinity norm does not include +// any NaN column absolute-value sums. +struct matrix_inf_norm_result { + // Accumulate errors in double, as this is generally + // the highest precision that the examples use. + double inf_norm = 0.0; + bool found_nan = false; +}; + +// In theory, cute::Tensor, T> could be treated as a view type, +// and thus passed by value (as std::span or std::string_view would be). +// However, generic cute::Tensor are more like containers +// and thus are best passed by reference or const reference. +template +matrix_inf_norm_result +matrix_inf_norm(const cute::Tensor& host_matrix) +{ + using std::abs; + using error_type = decltype(std::declval().inf_norm); + + error_type inf_norm = 0.0; + bool found_nan = false; + + const auto shape = host_matrix.shape(); + using index_type = std::decay_t(shape))>; + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + static_assert(std::is_integral_v); + const index_type num_rows = cute::get<0>(shape); + const index_type num_cols = cute::get<1>(shape); + + for(index_type i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(index_type j = 0; j < num_cols; ++j) { + row_abs_sum += abs(host_matrix(i, j)); + } + if(std::isnan(row_abs_sum)) { + found_nan = true; + } else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +// Infinity norm of (X - Y). +template +matrix_inf_norm_result +matrix_diff_inf_norm(const cute::Tensor& X, + const cute::Tensor& Y) +{ + using std::abs; + using error_type = decltype(std::declval().inf_norm); + + const auto X_shape = X.shape(); + const auto Y_shape = Y.shape(); + + using index_type = std::decay_t(X_shape))>; + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + static_assert(std::is_integral_v); + const index_type num_rows = cute::get<0>(X_shape); + const index_type num_cols = cute::get<1>(X_shape); + + assert(num_rows == cute::get<0>(Y_shape)); + assert(num_cols == cute::get<1>(Y_shape)); + + auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) { + return A(i, j); + }; + auto diff_ij = [&](std::size_t i, std::size_t j) { + return matrix_ij(X, i, j) - matrix_ij(Y, i, j); + }; + + error_type inf_norm = 0.0; + bool found_nan = false; + + for(index_type i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(index_type j = 0; j < num_cols; ++j) { + row_abs_sum += abs(diff_ij(i, j)); + } + if(std::isnan(row_abs_sum)) { + found_nan = true; + } else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +template +void +print_matrix_multiply_mollified_relative_error( + const char A_value_type_name[], + const cute::Tensor& A, + const char B_value_type_name[], + const cute::Tensor& B, + const char C_value_type_name[], + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + const auto [A_norm, A_has_nan] = matrix_inf_norm(A); + const auto [B_norm, B_has_nan] = matrix_inf_norm(B); + const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected); + + const auto A_norm_times_B_norm = A_norm * B_norm; + const auto relative_error = A_norm_times_B_norm == 0.0 ? + diff_norm : (diff_norm / A_norm_times_B_norm); + + // For expected error bounds, please refer to the LAPACK Users' Guide, + // in particular https://netlib.org/lapack/lug/node108.html . + // Printing the infinity norm of C is a way to check + // that both the function being tested (C_computed) + // and the reference implementation (C_expected) + // don't just do nothing (or fill with zeros). + using std::cout; + cout << "Value type of A: " << A_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Value type of B: " << B_value_type_name << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Value type of C: " << C_value_type_name << '\n' + << "Infinity norm of C_expected: " << C_norm << '\n' + << "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n'; + + if(A_norm_times_B_norm == 0.0) { + cout << "Mollified relative error: " << relative_error << '\n'; + } else { + cout << "Relative error: " << relative_error << '\n'; + } + + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C_computed - C_expected)? " + << (diff_has_nan ? "yes" : "no") << '\n'; +} + +template +void +print_matrix_multiply_mollified_relative_error( + const char value_type_name[], + const cute::Tensor& A, + const cute::Tensor& B, + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + value_type_name, C_computed, C_expected); +} + +// Take a CUTLASS HostTensor (or the like) as input, +// and return a const CuTe Tensor. +// This is useful for use with the above error printing functions. +// This implicitly "transposes" if the layout is RowMajor. +// Note that the HostTensor must be captured by nonconst reference +// in order for X.host_ref().data() to compile. +// (CUTLASS is a bit more container-y than CuTe.) +template +auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) +{ + // The tensors were created with post-transposed extents. + const auto extents = X.extent(); + const auto shape = cute::Shape{extents[0], extents[1]}; + // Both RowMajor and ColumnMajor only store one stride. + const int LDX = X.stride(0); + const auto strides = [&]() { + using input_layout_type = typename std::decay_t::Layout; + if constexpr (std::is_same_v) { + return cute::Stride{1, LDX}; + } + else { + static_assert(std::is_same_v); + return cute::Stride{LDX, 1}; + } + }(); + const auto layout = cute::make_layout(shape, strides); + auto X_data = X.host_ref().data(); + auto X_data_const = const_cast >(X_data); + return cute::make_tensor(X_data_const, layout); +}; diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp new file mode 100644 index 0000000000..64a0600b64 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_ // (M, N, L) +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using ElementA = typename MainloopParams::EngineA::value_type; + using ElementB = typename MainloopParams::EngineB::value_type; + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + a_frag[m_b] = static_cast(mainloop_params.A(m + m_b, k, l)); + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + b_frag[n_b] = static_cast(mainloop_params.B(n + n_b, k, l)); + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::EngineC::value_type; + + using ElementD = typename EpilogueParams::EngineD::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter destination_converter; + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + + ElementScalar output = epilogue_fma(converted_alpha, converted_acc, ElementCompute(0)); + output = epilogue_fma(converted_beta, converted_src, output); + + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{})); + static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{})); + static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); + + if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + Layout layout_A = make_layout( + make_shape(get<0>(mainloop_params.A.shape()), get<1>(mainloop_params.A.shape()), Int<1>{}), + make_stride(get<0>(mainloop_params.A.stride()), get<1>(mainloop_params.A.stride()), int64_t(cosize(mainloop_params.A.layout())))); + + Layout layout_B = make_layout( + make_shape(get<0>(mainloop_params.B.shape()), get<1>(mainloop_params.B.shape()), Int<1>{}), + make_stride(get<0>(mainloop_params.B.stride()), get<1>(mainloop_params.B.stride()), int64_t(cosize(mainloop_params.B.layout())))); + + Layout layout_C = make_layout( + make_shape(get<0>(epilogue_params.C.shape()), get<1>(epilogue_params.C.shape()), Int<1>{}), + make_stride(get<0>(epilogue_params.C.stride()), get<1>(epilogue_params.C.stride()), int64_t(cosize(epilogue_params.C.layout())))); + + Layout layout_D = make_layout( + make_shape(get<0>(epilogue_params.D.shape()), get<1>(epilogue_params.D.shape()), Int<1>{}), + make_stride(get<0>(epilogue_params.D.stride()), get<1>(epilogue_params.D.stride()), int64_t(cosize(epilogue_params.D.layout())))); + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000..a4a5b4e386 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename TensorL, + typename TensorR +> +bool TensorEquals( + TensorL lhs, + TensorR rhs) { + + // Extents must be identical + if (cute::size(lhs) != cute::size(rhs)) { + return false; + } + + for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { + if (lhs(idx) != rhs(idx)) { + return false; + } + } + + return true; +} + +/// Returns true if two tensor views are NOT equal. +template < + typename TensorL, + typename TensorR +> +bool TensorNotEquals( + TensorL lhs, + TensorR rhs) { + + return TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000..3262c53527 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000..aadf60ac7e --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////////