diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b9f0afc..ddfb751f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,15 @@ # NVIDIA CUTLASS Changelog -## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4) (2023-12-29) +## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12) * Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors. * Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) * Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above). * Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above). +* [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. * NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released. -* Improved [CuTe TMA Tensor](/media/docs/cute/0z_tma_tensors.md) documentation. +* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved. -## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31) +## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31) * [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. * [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}. * [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors. @@ -20,7 +21,7 @@ * Fusion support for backprop fusions including drelu, dgelu, and dbias. * Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface -## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-10-25) +## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25) * Minor patch for issue/1138 ## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0fa6feaa..bec6248a 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -208,6 +208,32 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code. # CUTLASS generator cmake configuration # +# Kernel unified filter file + +set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH") + +if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) + # If a kernel filter file is specified, we want to generate and then + # filter on the entire kernel set, not the default kernel + # (sub)set. The user may overried CUTLASS_LIBRRARY_KERNELS, in which + # case the resulting kernel set will be the intersection of the two + # options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS. + set(CUTLASS_LIBRARY_KERNELS_INIT "*") +else() + set(CUTLASS_LIBRARY_KERNELS_INIT "") +endif() + +if (KERNEL_FILTER_FILE) + get_filename_component(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" ABSOLUTE) + set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE) +endif() + +set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list") + +if(KERNEL_FILTER_FILE) + message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}") +endif() + set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.") set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.") diff --git a/README.md b/README.md index 53688efa..64c48cab 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # CUTLASS 3.4 -_CUTLASS 3.4 - December 2023_ +_CUTLASS 3.4 - January 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -48,7 +48,9 @@ CUTLASS 3.4.0 is an update to CUTLASS adding: - Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100. - Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above) - Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above) +- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. - Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library. +- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved. Minimum requirements: diff --git a/bin2hex.cmake b/bin2hex.cmake index 44935f2d..b34e0284 100644 --- a/bin2hex.cmake +++ b/bin2hex.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2019 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + # A small utility function which generates a C-header from an input file function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED) FILE(READ "${FILENAME}" HEX_INPUT HEX) diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake index 2e1e50d8..94394a50 100644 --- a/cmake/CTestTestfile.configure.cmake +++ b/cmake/CTestTestfile.configure.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + # Generated file set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@) diff --git a/cmake/CTestTestfile.test.configure.cmake b/cmake/CTestTestfile.test.configure.cmake index dad2c76c..fa2ceeb9 100644 --- a/cmake/CTestTestfile.test.configure.cmake +++ b/cmake/CTestTestfile.test.configure.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT) # The longform/extended format allows generator expressions to be # expanded property and is useful in contexts where the files need diff --git a/cmake/NvidiaCutlassPackageConfig.cmake b/cmake/NvidiaCutlassPackageConfig.cmake index bb15b1bb..364fba7a 100644 --- a/cmake/NvidiaCutlassPackageConfig.cmake +++ b/cmake/NvidiaCutlassPackageConfig.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + set(CPACK_PACKAGE_NAME NvidiaCutlass) set(CPACK_PACKAGE_VENDOR NVIDIA) set(CPACK_PACKAGE_CONTACT info@nvidia.com) diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index a16231a1..89d64e6d 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + include(FetchContent) set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") diff --git a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt index a20fa2b1..8b6700b7 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt +++ b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt @@ -37,4 +37,3 @@ cutlass_example_add_executable( 15_ampere_sparse_tensorop_gemm_with_visitor ampere_sparse_tensorop_gemm_with_visitor.cu ) - diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu index 972f8b69..a8e4f5fa 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu @@ -32,11 +32,9 @@ /** Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. - Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of meta data is different for every data types. CUTLASS templates can automatically infer it based on input A and B. Check code below. - Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers efficiently. */ @@ -307,7 +305,7 @@ int run() { // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), tensor_e.host_ref(), problem_size.m(), problem_size.k()); - + // Create instantiation for host reference gemm kernel cutlass::reference::host::Gemm 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; notSupported = true; diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h index dde3c073..7789e41f 100644 --- a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -14,7 +14,7 @@ * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from - * this layernormware without specific prior written permission. + * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h index aafc62d6..94711458 100644 --- a/examples/41_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h index 2a574e71..f8f06dfe 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h index a5d8f8d3..d1e313d7 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h index 2e286d3f..bc2a28c0 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/fmha_backward_test.py b/examples/41_fused_multi_head_attention/fmha_backward_test.py index ee0b7934..cdea9ded 100644 --- a/examples/41_fused_multi_head_attention/fmha_backward_test.py +++ b/examples/41_fused_multi_head_attention/fmha_backward_test.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + import argparse import torch import sys diff --git a/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu index 84662828..12e66aa1 100644 --- a/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu +++ b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index c4bb109d..b0fa1c98 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu index db7e6846..2ef68451 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/examples/41_fused_multi_head_attention/gemm/custom_mma.h index 7326bad5..ee53ecc9 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h index 6c6d0781..be25f79c 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h index 5441a0a0..893f765c 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h index 65743645..fd527a17 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/gemm/find_default_mma.h b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h index 2e6b35b6..d8d35b3b 100644 --- a/examples/41_fused_multi_head_attention/gemm/find_default_mma.h +++ b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h index ad2b7e02..fe200a0b 100644 --- a/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h index df510d6a..eecd8600 100644 --- a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h index 5740cab0..60e4928e 100644 --- a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h index 9a0885b6..3dbb0cf2 100644 --- a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h +++ b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h index 44f38dbc..64a58278 100644 --- a/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h +++ b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h index e6b5d58a..7d7ad367 100644 --- a/examples/41_fused_multi_head_attention/iterators/make_residual_last.h +++ b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h index 0f5bb84a..6bc9e52c 100644 --- a/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h +++ b/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h index 4bb96a13..4db56560 100644 --- a/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h +++ b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h index 1784bd2e..43c13a97 100644 --- a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h +++ b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h index 7e0dc6c7..d19b1907 100644 --- a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h +++ b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file diff --git a/examples/41_fused_multi_head_attention/kernel_backward.h b/examples/41_fused_multi_head_attention/kernel_backward.h index b2f4ed40..b06a8a62 100644 --- a/examples/41_fused_multi_head_attention/kernel_backward.h +++ b/examples/41_fused_multi_head_attention/kernel_backward.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 4abfe330..e14cf456 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/41_fused_multi_head_attention/piped_subprocess.py b/examples/41_fused_multi_head_attention/piped_subprocess.py index 713c641b..82351f49 100644 --- a/examples/41_fused_multi_head_attention/piped_subprocess.py +++ b/examples/41_fused_multi_head_attention/piped_subprocess.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + from typing import List import torch import subprocess diff --git a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h index 6c2d1764..4e1d6591 100644 --- a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h +++ b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -12,7 +12,7 @@ * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * 3. Neither the name of the copyright holdvr nor the names of its + * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * diff --git a/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp b/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp index 9caf0aa6..fbcf7f9f 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp @@ -101,6 +101,11 @@ struct CustomStride auto operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + CUTE_HOST_DEVICE friend void print(CustomStride const & s) { diff --git a/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp index 106e9897..e7b5f5a7 100644 --- a/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include "cute/tensor.hpp" diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb index 428d28f0..c2795551 100644 --- a/examples/python/00_basic_gemm.ipynb +++ b/examples/python/00_basic_gemm.ipynb @@ -9,7 +9,7 @@ "# Basic example of using the CUTLASS Python interface\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/python/00_basic_gemm.ipynb)\n" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/00_basic_gemm.ipynb)\n" ] }, { @@ -374,6 +374,7 @@ }, { "cell_type": "markdown", + "id": "0fff34a4", "metadata": {}, "source": [ "## Specializations for other data types\n", @@ -386,6 +387,7 @@ { "cell_type": "code", "execution_count": null, + "id": "338ad890", "metadata": {}, "outputs": [], "source": [ @@ -406,6 +408,7 @@ }, { "cell_type": "markdown", + "id": "65531df1", "metadata": {}, "source": [ "Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)" @@ -414,6 +417,7 @@ { "cell_type": "code", "execution_count": null, + "id": "776f1d8d", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/python/01_epilogue.ipynb b/examples/python/01_epilogue.ipynb index a58446e4..97663f50 100644 --- a/examples/python/01_epilogue.ipynb +++ b/examples/python/01_epilogue.ipynb @@ -9,7 +9,7 @@ "# Example of using elementwise activation functions in the CUTLASS Python interface\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/python/01_epilogue.ipynb)\n" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/01_epilogue.ipynb)\n" ] }, { diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb index b811c5e3..86c86fb6 100644 --- a/examples/python/02_pytorch_extension_grouped_gemm.ipynb +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ -10,7 +10,7 @@ "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d can also be exported as PyTorch CUDA extensions. \n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/python/02_pytorch_extension_grouped_gemm.ipynb)\n" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb)\n" ] }, { diff --git a/examples/python/03_basic_conv2d.ipynb b/examples/python/03_basic_conv2d.ipynb index c428319a..d0eb4526 100644 --- a/examples/python/03_basic_conv2d.ipynb +++ b/examples/python/03_basic_conv2d.ipynb @@ -8,7 +8,7 @@ "\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run Conv2d. \n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/python/03_basic_conv2d.ipynb)\n" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/03_basic_conv2d.ipynb)\n" ] }, { diff --git a/examples/python/04_epilogue_visitor.ipynb b/examples/python/04_epilogue_visitor.ipynb index 5a147bcb..cf66cd24 100644 --- a/examples/python/04_epilogue_visitor.ipynb +++ b/examples/python/04_epilogue_visitor.ipynb @@ -9,7 +9,7 @@ "# Example of using epilogue visitor in the CUTLASS Python interface\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues through CUTLASS Epilogue Visitor.\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/python/04_epilogue_visitor.ipynb)\n" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/04_epilogue_visitor.ipynb)\n" ] }, { diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index 1fee7424..a6884657 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -108,6 +108,28 @@ CUTE_NAMED_UNARY_OP(conjugate, cute::conj); #undef CUTE_RIGHT_UNARY_OP #undef CUTE_NAMED_UNARY_OP +template +struct shift_right_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return std::forward(arg) >> Shift; + } +}; + +template +struct shift_left_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return std::forward(arg) << Shift; + } +}; + /************/ /** Binary **/ /************/ diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 0912f243..7686bd06 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -604,8 +604,7 @@ unwrap(T const& t) } // -// Flatten a hierarchical tuple to a tuple of depth one. -// +// Flatten and Unflatten // template @@ -614,13 +613,15 @@ struct is_flat : true_type {}; template struct is_flat> : bool_constant<(true && ... && (not is_tuple::value))> {}; +// Flatten a hierarchical tuple to a tuple of depth one +// and wrap non-tuples into a rank-1 tuple. template CUTE_HOST_DEVICE constexpr auto flatten_to_tuple(T const& t) { if constexpr (is_tuple::value) { - if constexpr (is_flat::value) { + if constexpr (is_flat::value) { // Shortcut for perf return t; } else { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); @@ -632,13 +633,15 @@ flatten_to_tuple(T const& t) CUTE_GCC_UNREACHABLE; } +// Flatten a hierarchical tuple to a tuple of depth one +// and leave non-tuple untouched. template CUTE_HOST_DEVICE constexpr auto flatten(T const& t) { if constexpr (is_tuple::value) { - if constexpr (is_flat::value) { + if constexpr (is_flat::value) { // Shortcut for perf return t; } else { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); @@ -650,6 +653,43 @@ flatten(T const& t) CUTE_GCC_UNREACHABLE; } +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Unflatten a flat tuple into a hierarchical tuple +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + // // insert and remove and replace // @@ -728,6 +768,18 @@ replace_back(T const& t, X const& x) // Make a tuple of Xs of tuple_size N // +template +CUTE_HOST_DEVICE constexpr +auto +tuple_repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make repeated Xs of rank N +// + template CUTE_HOST_DEVICE constexpr auto @@ -743,7 +795,7 @@ repeat(X const& x) } // -// Make a tuple of Xs the same profile as tuple +// Make a tuple of Xs the same profile as tuple T // template @@ -864,48 +916,6 @@ prepend(T const& a, X const& x) CUTE_GCC_UNREACHABLE; } -// -// Unflatten a flat tuple into a hierarchical one -// unflatten(x, flatten(x)) == x -// - -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) -{ - if constexpr (is_tuple::value) { - return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { - auto [result, remaining_tuple] = v; - auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); - return cute::make_tuple(append(result, sub_result), sub_tuple); - }); - } else { - return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - -// @pre flatten(@a flat_tuple) == @a flat_tuple -// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) -// @post congruent(@a result, @a target_profile) -// @post flatten(@a result) == @a flat_tuple -template -CUTE_HOST_DEVICE constexpr -auto -unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) -{ - auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); - CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); - return unflatten_tuple; -} - - // // Inclusive scan (prefix sum) // diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 0b6d40e3..0ff2207a 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -63,7 +63,7 @@ initialize_barrier(uint64_t& smem_barrier, // 64 bits user-mange { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); - asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n" + asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n" :: "r"(smem_int_ptr), "r"(thread_count)); #endif @@ -77,7 +77,7 @@ set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-mange { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); - asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" :: "r"(smem_int_ptr), "r"(bytes)); #endif @@ -95,7 +95,7 @@ wait_barrier(uint64_t& smem_barrier, // 64 bits user-mange "{\n" ".reg .pred P1;\n" "LAB_WAIT:\n" - "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" "@P1 bra.uni DONE;\n" "bra.uni LAB_WAIT;\n" "DONE:\n" @@ -116,7 +116,7 @@ arrive_barrier(uint64_t& smem_barrier) // 64 bits user-mang asm volatile( "{\n" ".reg .b64 state; \n" - "mbarrier.arrive.shared.b64 state, [%0];\n" + "mbarrier.arrive.shared::cta.b64 state, [%0];\n" "}\n" :: "r"(smem_int_ptr)); #endif diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index 64561fa1..1cee66e0 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -854,11 +854,12 @@ rs_op_selector() // FP32 accumulator else if constexpr (is_same_v) { - static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // FP16 inputs if constexpr (is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32F16F16_RS{}; } @@ -891,6 +892,7 @@ rs_op_selector() // BF16 inputs else if constexpr (is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32BF16BF16_RS{}; @@ -925,6 +927,7 @@ rs_op_selector() else if constexpr (is_same_v) { static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x8_F32TF32TF32_RS_TN{}; @@ -1023,7 +1026,7 @@ rs_op_selector() return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; } else { - static_aRSert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 76c48c2b..ec655a66 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -65,9 +65,9 @@ struct Copy_Atom, CopyInternalType> using ValType = CopyInternalType; - using ValLayoutSrc = decltype(upcast::value>(BitLayoutSrc{})); - using ValLayoutDst = decltype(upcast::value>(BitLayoutDst{})); - using ValLayoutRef = decltype(upcast::value>(BitLayoutRef{})); + using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); @@ -479,20 +479,24 @@ make_tiled_copy(Copy_Atom const& copy_atom, ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx ValLayout const& val_layout = {}) // (m,n) -> val_idx { - constexpr int R = cute::max(rank_v, rank_v); - - auto thr_layout_mn = append(thr_layout, Layout<_1>{}); - auto val_layout_mn = append(val_layout, Layout<_1>{}); - // Take the raked_products to compute the Layout_MN - auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); - // print("thr_layout: "); print(thr_layout_mn); print("\n"); - // print("val_layout: "); print(val_layout_mn); print("\n"); - // print("layout_mn : "); print(layout_mn); print("\n"); - // print("layout_tv : "); print(layout_tv); print("\n"); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); - return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); } /** Produce a TiledCopy from thread and value offset maps. @@ -622,7 +626,7 @@ print(Copy_Atom, T> const&) print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); - print(" ValueType: %db\n", int(sizeof_bits::value)); + print(" ValueType: "); print(sizeof_bits::value); print("b\n"); } template @@ -755,6 +759,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and #include #include #include + // Config #if (__CUDACC_VER_MAJOR__ >= 12) # define CUTE_COPY_ATOM_TMA_SM90_ENABLED diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 2ae6e1e8..2e5aa59d 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -673,15 +673,14 @@ fill_tma_gmem_shape_stride(Tensor const& gtensor, / // Trivial contribution of this gmem mode to this tma mode auto ej = unwrap(get(tma_gbasis_stride)); gmem_prob_shape[i] = basis_get(ej, gmem_shape); - gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v / 8; + gmem_prob_stride[i] = basis_get(ej, gmem_stride); } else { // Apply a recurrence to each gmem mode that contributes to this tma mode for_each(get(tma_gbasis_stride), [&](auto ej) { // Problem shape uint64_t shape_j = basis_get(ej, gmem_shape); // Problem stride (in bytes) - uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v / 8; - + uint64_t stride_j = basis_get(ej, gmem_stride); uint64_t old_stride = gmem_prob_stride[i]; gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); @@ -764,8 +763,14 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). - assert(gmem_prob_stride[0] == sizeof(TmaInternalType) && "Majorness of smem doesn't match majorness of gmem"); + assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem"); + + // convert strides to byte strides + for(uint64_t& stride : gmem_prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + // Assert the byte strides. Tma Descriptor uses byte strides assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 @@ -866,8 +871,8 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin } #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) - auto recast_ratio = cute::ratio(Int::value>{}, - Int::value>{}); + auto recast_ratio = cute::trait_ratio(sizeof_bits{}, + sizeof_bits< TmaInternalType>{}); auto gbasis = make_basis_like(shape(gtensor)); @@ -943,7 +948,7 @@ make_tma_copy_atom(CopyOp, // Construct the Copy_Traits // - constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v; + constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits::value; using Traits = Copy_Traits, decltype(aux_params)>; using Atom = Copy_Atom; @@ -985,7 +990,7 @@ make_tma_copy_tiled(CopyOp const& copy_op, [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); - auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int>{}; + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); // smem idx -> smem coord auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 27456f38..045f33dc 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -55,10 +55,10 @@ struct MMA_Atom> using Traits = MMA_Traits; // Element value types from the MMA_Traits - using ValTypeD = typename Traits::ElementDVal; - using ValTypeA = typename Traits::ElementAVal; - using ValTypeB = typename Traits::ElementBVal; - using ValTypeC = typename Traits::ElementCVal; + using ValTypeD = typename Traits::ValTypeD; + using ValTypeA = typename Traits::ValTypeA; + using ValTypeB = typename Traits::ValTypeB; + using ValTypeC = typename Traits::ValTypeC; // Thr-Val layouts from the MMA_Traits using Shape_MNK = typename Traits::Shape_MNK; diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 56145934..3940db5f 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -50,14 +50,14 @@ struct supports_output_scaling().accumulate_)>> { /** * concept MMA_Traits * { - * using ElementDVal = // Logical A-value type - * using ElementAVal = // Logical B-value type - * using ElementBVal = // Logical C-value type - * using ElementCVal = // Logical D-value type (NOTE: Not used? Assumed == ElementDVal) + * using ValTypeD = // Logical A-value type + * using ValTypeA = // Logical B-value type + * using ValTypeB = // Logical C-value type + * using ValTypeC = // Logical D-value type (NOTE: Not used? Assumed == ValTypeD) * - * using ElementAFrg = // A-type consumed by MMA (if ommitted, same as ElementAVal) - * using ElementBFrg = // B_type consumed by MMA (if ommitted, same as ElementBVal) - * using ElementCFrg = // C_type consumed by MMA (if ommitted, same as ElementCVal) + * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) + * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) + * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) * * using Shape_MNK = // Logical MxNxK shape of the MMA * @@ -78,10 +78,10 @@ struct MMA_Traits template struct MMA_Traits> { - using ElementDVal = D; - using ElementAVal = A; - using ElementBVal = B; - using ElementCVal = C; + using ValTypeD = D; + using ValTypeA = A; + using ValTypeB = B; + using ValTypeC = C; // Logical shape of the MMA using Shape_MNK = Shape<_1,_1,_1>; @@ -209,19 +209,19 @@ mma_unpack(MMA_Traits const& traits, namespace detail { template -struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; +struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; template -struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; +struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; template -struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; +struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; template -struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; +struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; template -struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; +struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; template -struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; +struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; } // end namespace detail diff --git a/include/cute/atom/mma_traits_sm61.hpp b/include/cute/atom/mma_traits_sm61.hpp index 85d4e987..096f0acb 100644 --- a/include/cute/atom/mma_traits_sm61.hpp +++ b/include/cute/atom/mma_traits_sm61.hpp @@ -41,10 +41,10 @@ namespace cute template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_1,_1,_4>; using ThrID = Layout<_1>; @@ -58,10 +58,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int16_t; - using ElementBVal = int16_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int16_t; + using ValTypeB = int16_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_1,_1,_2>; using ThrID = Layout<_1>; diff --git a/include/cute/atom/mma_traits_sm70.hpp b/include/cute/atom/mma_traits_sm70.hpp index 79430350..72532f44 100644 --- a/include/cute/atom/mma_traits_sm70.hpp +++ b/include/cute/atom/mma_traits_sm70.hpp @@ -63,10 +63,10 @@ using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -80,10 +80,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -97,10 +97,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -114,10 +114,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -131,10 +131,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -148,10 +148,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -165,10 +165,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -182,10 +182,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; diff --git a/include/cute/atom/mma_traits_sm75.hpp b/include/cute/atom/mma_traits_sm75.hpp index 63f83466..2f0d6ec2 100644 --- a/include/cute/atom/mma_traits_sm75.hpp +++ b/include/cute/atom/mma_traits_sm75.hpp @@ -41,10 +41,10 @@ namespace cute template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_16,_8,_8>; using ThrID = Layout<_32>; @@ -61,10 +61,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_8,_8,_16>; using ThrID = Layout<_32>; diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp index 6636b7aa..34740368 100644 --- a/include/cute/atom/mma_traits_sm80.hpp +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -66,10 +66,10 @@ using SM80_16x8_Row = Layout,Shape < _2,_2>>, template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_16,_8,_8>; using ThrID = Layout<_32>; @@ -81,10 +81,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; using Shape_MNK = Shape<_16,_8,_16>; using ThrID = Layout<_32>; @@ -103,20 +103,20 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; }; template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; }; /////////////////////////////////////////////////////////////////////////////// @@ -127,20 +127,20 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; }; template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; }; /////////////////////////////////////////////////////////////////////////////// @@ -150,10 +150,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = cutlass::tfloat32_t; - using ElementBVal = cutlass::tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; using Shape_MNK = Shape<_16,_8,_4>; using ThrID = Layout<_32>; @@ -166,10 +166,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = cutlass::tfloat32_t; - using ElementBVal = cutlass::tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; using Shape_MNK = Shape<_16,_8,_8>; using ThrID = Layout<_32>; @@ -187,10 +187,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = double; - using ElementAVal = double; - using ElementBVal = double; - using ElementCVal = double; + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = Layout<_32>; @@ -204,10 +204,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = complex; - using ElementAVal = complex; - using ElementBVal = complex; - using ElementCVal = complex; + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; }; // Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts @@ -215,10 +215,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; - using ElementAVal = complex; - using ElementBVal = complex; - using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; }; /////////////////////////////////////////////////////////////////////////////// @@ -228,10 +228,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_8,_8,_16>; using ThrID = Layout<_32>; @@ -247,10 +247,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_16,_8,_16>; using ThrID = Layout<_32>; @@ -267,10 +267,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_16,_8,_32>; using ThrID = Layout<_32>; @@ -293,10 +293,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -307,10 +307,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -321,10 +321,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -339,10 +339,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; }; template <> @@ -353,10 +353,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; }; template <> @@ -367,10 +367,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; }; template <> @@ -385,10 +385,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -399,10 +399,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -413,10 +413,10 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; }; template <> @@ -430,10 +430,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = cute::uint1b_t; - using ElementBVal = cute::uint1b_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; using Shape_MNK = Shape<_16,_8,_256>; using ThrID = Layout<_32>; diff --git a/include/cute/atom/mma_traits_sm90.hpp b/include/cute/atom/mma_traits_sm90.hpp index b7a12b98..fae1eaca 100644 --- a/include/cute/atom/mma_traits_sm90.hpp +++ b/include/cute/atom/mma_traits_sm90.hpp @@ -44,10 +44,10 @@ namespace cute { template <> struct MMA_Traits { - using ElementDVal = double; - using ElementAVal = double; - using ElementBVal = double; - using ElementCVal = double; + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; using Shape_MNK = Shape<_16,_8,_4>; using ThrID = Layout<_32>; @@ -62,10 +62,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = double; - using ElementAVal = double; - using ElementBVal = double; - using ElementCVal = double; + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; using Shape_MNK = Shape<_16,_8,_8>; using ThrID = Layout<_32>; @@ -80,10 +80,10 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = double; - using ElementAVal = double; - using ElementBVal = double; - using ElementCVal = double; + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; using Shape_MNK = Shape<_16,_8,_16>; using ThrID = Layout<_32>; @@ -103,30 +103,30 @@ template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = complex; - using ElementAVal = complex; - using ElementBVal = complex; - using ElementCVal = complex; + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; }; template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = complex; - using ElementAVal = complex; - using ElementBVal = complex; - using ElementCVal = complex; + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; }; template <> struct MMA_Traits : MMA_Traits { - using ElementDVal = complex; - using ElementAVal = complex; - using ElementBVal = complex; - using ElementCVal = complex; + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; }; } // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 27d40e3d..404e6550 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -426,13 +426,13 @@ using ABLayout = Layout,Int>>, template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -448,12 +448,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -469,13 +469,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -491,12 +491,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -512,13 +512,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -534,12 +534,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -555,13 +555,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -577,12 +577,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -598,13 +598,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -620,12 +620,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -641,13 +641,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -663,12 +663,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -684,13 +684,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -706,12 +706,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -727,13 +727,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -749,12 +749,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -770,13 +770,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -792,12 +792,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -813,13 +813,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -835,12 +835,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -856,13 +856,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -878,12 +878,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -899,13 +899,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -921,12 +921,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -942,13 +942,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -964,12 +964,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -985,13 +985,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -1007,12 +1007,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -1028,13 +1028,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -1050,12 +1050,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -1071,13 +1071,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -1093,12 +1093,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -1114,13 +1114,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -1136,12 +1136,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; @@ -1157,13 +1157,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -1179,12 +1179,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; @@ -1200,13 +1200,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -1222,12 +1222,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; @@ -1243,13 +1243,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -1265,12 +1265,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; @@ -1286,13 +1286,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -1308,12 +1308,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; @@ -1329,13 +1329,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -1351,12 +1351,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; @@ -1372,13 +1372,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -1394,12 +1394,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; @@ -1415,13 +1415,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -1437,12 +1437,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = bfloat16_t; - using ElementBVal = bfloat16_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; @@ -1458,13 +1458,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_8>; using ThrID = Layout<_128>; @@ -1480,12 +1480,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_8>; using ThrID = Layout<_128>; @@ -1501,13 +1501,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_8>; using ThrID = Layout<_128>; @@ -1523,12 +1523,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_8>; using ThrID = Layout<_128>; @@ -1544,13 +1544,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_8>; using ThrID = Layout<_128>; @@ -1566,12 +1566,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_8>; using ThrID = Layout<_128>; @@ -1587,13 +1587,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_8>; using ThrID = Layout<_128>; @@ -1609,12 +1609,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_8>; using ThrID = Layout<_128>; @@ -1630,13 +1630,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_8>; using ThrID = Layout<_128>; @@ -1652,12 +1652,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_8>; using ThrID = Layout<_128>; @@ -1673,13 +1673,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_8>; using ThrID = Layout<_128>; @@ -1695,12 +1695,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_8>; using ThrID = Layout<_128>; @@ -1716,13 +1716,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_8>; using ThrID = Layout<_128>; @@ -1738,12 +1738,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_8>; using ThrID = Layout<_128>; @@ -1759,13 +1759,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_8>; using ThrID = Layout<_128>; @@ -1781,12 +1781,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = tfloat32_t; - using ElementBVal = tfloat32_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_8>; using ThrID = Layout<_128>; @@ -1802,13 +1802,13 @@ struct MMA_Traits> template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -1824,13 +1824,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -1846,13 +1846,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -1868,13 +1868,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -1890,13 +1890,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -1912,13 +1912,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -1934,13 +1934,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -1956,13 +1956,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -1978,12 +1978,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -1999,12 +1999,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2020,12 +2020,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2041,12 +2041,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2062,12 +2062,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2083,12 +2083,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2104,12 +2104,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2125,12 +2125,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -2146,13 +2146,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -2168,13 +2168,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2190,13 +2190,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2212,13 +2212,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2234,13 +2234,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2256,13 +2256,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2278,13 +2278,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2300,13 +2300,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -2322,12 +2322,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -2343,12 +2343,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2364,12 +2364,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2385,12 +2385,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2406,12 +2406,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2427,12 +2427,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2448,12 +2448,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2469,12 +2469,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = int8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -2490,13 +2490,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -2512,13 +2512,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2534,13 +2534,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2556,13 +2556,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2578,13 +2578,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2600,13 +2600,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2622,13 +2622,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2644,13 +2644,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -2666,12 +2666,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -2687,12 +2687,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2708,12 +2708,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2729,12 +2729,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2750,12 +2750,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2771,12 +2771,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2792,12 +2792,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2813,12 +2813,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = int8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -2834,13 +2834,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -2856,13 +2856,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -2878,13 +2878,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -2900,13 +2900,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -2922,13 +2922,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -2944,13 +2944,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -2966,13 +2966,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -2988,13 +2988,13 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3010,12 +3010,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3031,12 +3031,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3052,12 +3052,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -3073,12 +3073,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -3094,12 +3094,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -3115,12 +3115,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -3136,12 +3136,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -3157,12 +3157,12 @@ struct MMA_Traits template <> struct MMA_Traits { - using ElementDVal = int32_t; - using ElementAVal = uint8_t; - using ElementBVal = uint8_t; - using ElementCVal = int32_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3178,13 +3178,13 @@ struct MMA_Traits template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3200,12 +3200,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3221,13 +3221,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3243,12 +3243,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3264,13 +3264,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3286,12 +3286,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3307,13 +3307,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3329,12 +3329,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3350,13 +3350,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -3372,12 +3372,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -3393,13 +3393,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -3415,12 +3415,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -3436,13 +3436,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -3458,12 +3458,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -3479,13 +3479,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -3501,12 +3501,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -3522,13 +3522,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -3544,12 +3544,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -3565,13 +3565,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -3587,12 +3587,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -3608,13 +3608,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -3630,12 +3630,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -3651,13 +3651,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -3673,12 +3673,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -3694,13 +3694,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -3716,12 +3716,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -3737,13 +3737,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -3759,12 +3759,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -3780,13 +3780,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3802,12 +3802,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3823,13 +3823,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3845,12 +3845,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -3866,13 +3866,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3888,12 +3888,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3909,13 +3909,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3931,12 +3931,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -3952,13 +3952,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3974,12 +3974,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -3995,13 +3995,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4017,12 +4017,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4038,13 +4038,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4060,12 +4060,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4081,13 +4081,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4103,12 +4103,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4124,13 +4124,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4146,12 +4146,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4167,13 +4167,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4189,12 +4189,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4210,13 +4210,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4232,12 +4232,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4253,13 +4253,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4275,12 +4275,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4296,13 +4296,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -4318,12 +4318,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -4339,13 +4339,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -4361,12 +4361,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -4382,13 +4382,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -4404,12 +4404,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -4425,13 +4425,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -4447,12 +4447,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -4468,13 +4468,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -4490,12 +4490,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -4511,13 +4511,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -4533,12 +4533,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e4m3_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -4554,13 +4554,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -4576,12 +4576,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -4597,13 +4597,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -4619,12 +4619,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -4640,13 +4640,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4662,12 +4662,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4683,13 +4683,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4705,12 +4705,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -4726,13 +4726,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4748,12 +4748,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4769,13 +4769,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4791,12 +4791,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -4812,13 +4812,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4834,12 +4834,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4855,13 +4855,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4877,12 +4877,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -4898,13 +4898,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4920,12 +4920,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4941,13 +4941,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4963,12 +4963,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -4984,13 +4984,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5006,12 +5006,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5027,13 +5027,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5049,12 +5049,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5070,13 +5070,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5092,12 +5092,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5113,13 +5113,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5135,12 +5135,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5156,13 +5156,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5178,12 +5178,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5199,13 +5199,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5221,12 +5221,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e4m3_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5242,13 +5242,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -5264,12 +5264,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -5285,13 +5285,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -5307,12 +5307,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; @@ -5328,13 +5328,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -5350,12 +5350,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -5371,13 +5371,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -5393,12 +5393,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; @@ -5414,13 +5414,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -5436,12 +5436,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -5457,13 +5457,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -5479,12 +5479,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; @@ -5500,13 +5500,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -5522,12 +5522,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -5543,13 +5543,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -5565,12 +5565,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; @@ -5586,13 +5586,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -5608,12 +5608,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -5629,13 +5629,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -5651,12 +5651,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; @@ -5672,13 +5672,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5694,12 +5694,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5715,13 +5715,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5737,12 +5737,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; @@ -5758,13 +5758,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5780,12 +5780,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5801,13 +5801,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5823,12 +5823,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; @@ -5844,13 +5844,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5866,12 +5866,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = half_t; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = half_t; + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5887,13 +5887,13 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementAFrg = GMMA::smem_desc; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; @@ -5909,12 +5909,12 @@ struct MMA_Traits> template struct MMA_Traits> { - using ElementDVal = float; - using ElementAVal = float_e5m2_t; - using ElementBVal = float_e5m2_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; - using ElementBFrg = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 2522d6b6..ce29bdd5 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -479,8 +479,9 @@ weakly_congruent(IntTupleA const& a, IntTupleB const& b) template using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); -/** Test if Shape B is compatible with Shape A: - * Any coordinate into A can also be used as a coordinate into B +/** Test if Shape A is compatible with Shape B: + * the size of A and B are the same, and + * any coordinate into A can also be used as a coordinate into B * compatible is a partial order on A and B: A <= B */ template @@ -509,8 +510,8 @@ compatible(IntTupleA const& a, IntTupleB const& b) template using is_compatible = decltype(compatible(declval(), declval())); -/** Test if Shape B is weakly compatible with Shape A: - * Shape B is a multiple of a shape that is compatible with Shape A +/** Test if Shape A is weakly compatible with Shape B: + * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) * weakly_compatible is a partial order on A and B: A <= B */ template diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 2d01fd51..df05a852 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -36,6 +36,8 @@ #include #include #include +#include +#include namespace cute { @@ -167,16 +169,6 @@ struct Layout return operator()(make_coord(c0,c1,cs...)); } - // Map a linear index to a hier ND logical coordinate - // NOTE: Dangerous and error-prone - template - CUTE_HOST_DEVICE constexpr - auto - operator[](Int const& linear_idx) const { - static_assert(is_integral::value); - return get_hier_coord(linear_idx); - } - // // Compose // @@ -305,11 +297,24 @@ struct Layout #endif }; +// Equality, return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + template struct is_layout : false_type {}; template struct is_layout> : true_type {}; +// +// Layout construction +// template ::value || is_integral::value) && @@ -446,51 +451,59 @@ make_identity_layout(Shape const& shape) // Operations to manipulate Layouts like a tuple of pairs // +// Return the Is...th sublayout. +// For Is... = , equivalent to get(...get(get(layout))) template CUTE_HOST_DEVICE constexpr auto get(Layout const& layout) { - // Let the static_asserts in get(shape|stride) catch problems - return make_layout(get(layout.shape()), get(layout.stride())); + return make_layout(get(layout.shape()), + get(layout.stride())); } +// Return a new layout with only the modes in the range [B,E) template CUTE_HOST_DEVICE constexpr auto take(Layout const& layout) { - // Let the static_asserts in take(shape|stride) catch problems - return make_layout(take(layout.shape()), take(layout.stride())); + static_assert(B < E, "take: empty range error"); + static_assert(0 <= B && E <= Layout::rank, "take: range out of bounds"); + return make_layout(take(layout.shape()), + take(layout.stride())); } -// -// Select layout modes according to an index sequence. -// - -template +// Return a new layout with only the modes Is... = +template CUTE_HOST_DEVICE constexpr auto select(Layout const& layout) { - return make_layout(select(layout.shape()), - select(layout.stride())); + return make_layout(select(layout.shape()), + select(layout.stride())); } +// Return a layout with depth at most 1 template CUTE_HOST_DEVICE constexpr auto flatten(Layout const& layout) { - return make_layout(flatten(layout.shape()), flatten(layout.stride())); + return make_layout(flatten(layout.shape()), + flatten(layout.stride())); } +// Return a layout whose profile is congruent to TargetProfile +// @pre Input layout is flat, flatten(@a layout) == @a layout +// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile)) +// @post congruent(@a result, @a target_profile) template CUTE_HOST_DEVICE constexpr auto unflatten(Layout const& layout, TargetProfile const& target_profile) { - return make_layout(unflatten(layout.shape(), target_profile), + return make_layout(unflatten(layout.shape(), target_profile), unflatten(layout.stride(), target_profile)); } @@ -498,7 +511,7 @@ unflatten(Layout const& layout, TargetProfile const& target_profil // Utilities // -// Return the layout of a mode +// Return the sublayout of mode I... template CUTE_HOST_DEVICE constexpr decltype(auto) @@ -609,17 +622,6 @@ using cosize_t = decltype(cosize(declval())); template static constexpr int cosize_v = cosize_t::value; -// Equality -// Return a static or dynamic boolean -template -CUTE_HOST_DEVICE constexpr -auto -operator==(Layout const& layoutA, Layout const& layoutB) -{ - return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); -} - // With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well template CUTE_HOST_DEVICE constexpr @@ -762,8 +764,11 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, } // end namespace detail -// Combine all the modes that are possible to combine -// Does not respect the profile of the layout, but does preserve total size +// "Simplify" the layout by combining modes that are possible to combine +// Does not respect the shape of the layout, but does preserve total size +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i) template CUTE_HOST_DEVICE constexpr auto @@ -894,7 +899,7 @@ group(Layout const& layout) // Composition of two layouts: lhs o rhs // @post compatible(rhs, result) // @post result(c) = lhs(rhs(c)) -// for all c in the domain of result +// for all c in the domain of rhs // namespace detail { @@ -984,19 +989,19 @@ composition(Layout const& lhs, return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); } -template +template CUTE_HOST_DEVICE constexpr auto composition(Layout const& lhs, - IntTuple const& rhs) + Tiler const& rhs) { - if constexpr (is_tuple::value) { - static_assert(tuple_size::value <= Layout::rank); + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); // Drop any modes of lhs that aren't hit by rhs - return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); - } else if constexpr (is_underscore::value) { + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { return lhs; - } else if constexpr (is_integral::value) { + } else if constexpr (is_integral::value) { return detail::composition_impl(lhs, rhs, Int<1>{}); } @@ -1041,19 +1046,25 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) auto [shape, stride, result_shape, result_stride] = init; auto min_stride = cute::min(stride); auto min_idx = find(stride, min_stride); - - return cute::make_tuple(remove(shape), // Remove the min_idx from shape - remove(stride), // Remove the min_idx from stride - append(result_shape , min_stride / get(result_stride)), // new shape = min_stride / last_stride - append(result_stride, get(shape) * min_stride)); // new stride = curr_shape * min_stride + auto new_shape = min_stride / get(result_stride); + auto new_stride = get(shape) * min_stride; + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , new_shape ), // new shape = min_stride / last_stride + append(result_stride, new_stride)); // new stride = curr_shape * min_stride }); // Append the last shape mode - auto result_shape = append(result_shape_, get<0>(stride_) / get(result_stride)); // new shape = min_stride / last_stride + auto new_shape = get<0>(stride_) / get(result_stride); + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + auto result_shape = append(result_shape_, new_shape); // new shape = min_stride / last_stride // Compute the rest_shape and rest_stride auto rest_stride = get<0>(shape_) * get<0>(stride_); auto rest_shape = ceil_div(cosize_hi, rest_stride); + // Jump into coalesce and append (rest_shape, rest_stride) return detail::bw_coalesce(result_shape, result_stride, rest_shape, rest_stride); } @@ -1323,14 +1334,14 @@ zip(Layout const& layoutA, // their own mode. // -template +template CUTE_HOST_DEVICE constexpr auto tile_unzip(Layout const& layout, - IntTuple const& tile) + Tiler const& tiler) { - return make_layout(zip2_by(layout.shape(), tile), - zip2_by(layout.stride(), tile)); + return make_layout(zip2_by(layout.shape(), tiler), + zip2_by(layout.stride(), tiler)); } // @@ -1389,10 +1400,10 @@ auto tiled_divide(Layout const& layout, Tiler const& tiler) { - auto div = zipped_divide(layout, tiler); + auto result = zipped_divide(layout, tiler); - auto R = rank<1>(div); - return div(_, repeat(_)); + auto R1 = rank<1>(result); + return result(_, repeat(_)); } // Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) @@ -1403,40 +1414,41 @@ auto flat_divide(Layout const& layout, Tiler const& tiler) { - auto div = zipped_divide(layout, tiler); + auto result = zipped_divide(layout, tiler); - auto R0 = rank<0>(div); - auto R1 = rank<1>(div); - return div(repeat(_), repeat(_)); + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); } // // Logical product // +// @post compatible() template CUTE_HOST_DEVICE constexpr auto -logical_product(Layout const& layout, +logical_product(Layout const& block, Layout const& tiler) { - return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler)); + return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler)); } template CUTE_HOST_DEVICE constexpr auto -logical_product(Layout const& layout, +logical_product(Layout const& block, Tiler const& tiler) { if constexpr (is_tuple::value) { static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); - return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); } else if constexpr (is_underscore::value) { - return layout; + return block; } else if constexpr (is_integral::value) { - return logical_product(layout, make_layout(tiler)); + return logical_product(block, make_layout(tiler)); } CUTE_GCC_UNREACHABLE; @@ -1452,10 +1464,10 @@ template CUTE_HOST_DEVICE constexpr auto -zipped_product(Layout const& layout, +zipped_product(Layout const& block, Tiler const& tiler) { - return tile_unzip(logical_product(layout, tiler), tiler); + return tile_unzip(logical_product(block, tiler), tiler); } // Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) @@ -1463,69 +1475,107 @@ template CUTE_HOST_DEVICE constexpr auto -tiled_product(Layout const& layout, +tiled_product(Layout const& block, Tiler const& tiler) { - auto div = zipped_product(layout, tiler); + auto result = zipped_product(block, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); - auto R = rank<1>(div); - return div(_, repeat(_)); + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); } -// Attempts to reproduce a layout over a tiler -// That is, think of every element of "tiler" as a "layout" -// and return the layout of the resulting structure +// +// Rank-sensitive products +// + +// blocked_product -- Reproduce a block over a tiler. +// Think of every element of "tiler" as a "block" +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) template CUTE_HOST_DEVICE constexpr auto -blocked_product(Layout const& layout, +blocked_product(Layout const& block, Layout const& tiler) { constexpr int R = cute::max(rank_v, rank_v); - auto result = logical_product(append(layout), append(tiler)); + auto result = logical_product(append(block), append(tiler)); - return coalesce(zip(get<0>(result), get<1>(result)), repeat(Int<1>{})); + return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat(Int<1>{})); } +// raked_product -- Reproduce a block over a tiler with block-interleaving. +// Think of every element of "tiler" as a "block", interleave those blocks, +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) template CUTE_HOST_DEVICE constexpr auto -raked_product(Layout const& layout, +raked_product(Layout const& block, Layout const& tiler) { constexpr int R = cute::max(rank_v, rank_v); - auto result = logical_product(append(layout), append(tiler)); + auto result = logical_product(append(block), append(tiler)); - return coalesce(zip(get<1>(result), get<0>(result)), repeat(Int<1>{})); + return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat(Int<1>{})); } +// tile_to_shape -- Perform a product of a layout so that the result matches a target shape. +// This is similar to blocked_product, but specifies the result shape instead of the +// product shape, which is more convenient in certain circumstances. +// @param block The layout to repeat +// @param trg_shape The target shape of the result +// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with. +// Defaults to GenColMajor, so @a layout will repeat +// across the first mode first, the second mode second, etc +// E.g. Step<_2,_1,_3> will cause @a layout to repeat +// across the second mode first, the first mode second, and the third mode last. +// @pre rank(@a block) <= rank(@a trg_shape) +// @post compatible(@a trg_shape, shape(@a result)) template + class TrgShape, class ModeOrder = LayoutLeft> CUTE_HOST_DEVICE constexpr auto -tile_to_shape(Layout const& layout, +tile_to_shape(Layout const& block, TrgShape const& trg_shape, ModeOrder const& ord_shape = {}) { - CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); constexpr int R = rank_v; - auto padded_layout = append(layout); + auto padded_block = append(block); - auto layout_shape = product_each(padded_layout.shape()); - auto target_shape = product_each(trg_shape); + auto block_shape = product_each(shape(padded_block)); + auto target_shape = product_each(shape(trg_shape)); // Assert proper division - CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{}, - "Layout shape does not divide the target shape."); + if constexpr (is_static::value) { + CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape), + "tile_to_shape: block shape does not divide the target shape."); + } - auto product_shape = shape_div(target_shape, layout_shape); + auto product_shape = ceil_div(target_shape, block_shape); - return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape); + return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape); } // @@ -1602,15 +1652,20 @@ CUTE_HOST_DEVICE constexpr auto recast_layout(Layout const& layout) { - if constexpr (sizeof_bits::value == sizeof_bits::value) { + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { return layout; - } else if constexpr (sizeof_bits::value > sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); - return upcast::value/sizeof_bits::value>(layout); - } else if constexpr (sizeof_bits::value < sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); - return downcast::value/sizeof_bits::value>(layout); } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + static_assert(dependent_false, "Recast not supported."); + } + CUTE_GCC_UNREACHABLE; } @@ -1693,12 +1748,13 @@ print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) a } // Generic 2D Layout to Latex printer -- B&W 8-value color coding -template +template CUTE_HOST_DEVICE void -print_latex(Layout const& layout) // (m,n) -> idx +print_latex(LayoutA const& layout_a) { - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); + auto layout = append<2>(layout_a, Layout<_1,_0>{}); char const* latex_header = "\\documentclass[convert]{standalone}\n" @@ -1727,7 +1783,6 @@ print_latex(Layout const& layout) // (m,n) -> idx for (int i = 0; i < size<0>(layout); ++i) { for (int j = 0; j < size<1>(layout); ++j) { int idx = layout(i,j); - printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", color_map[idx % 8], i, j, diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 69e47182..b16877f1 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -37,7 +37,7 @@ /* This implements a ComposedLayout of the form * LayoutA o Offset o LayoutB * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. - * For example, then the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). + * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). * * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. @@ -357,12 +357,11 @@ composition(LayoutA const& layoutA, return ComposedLayout{layoutA, offset, layoutB}; } -template +template CUTE_HOST_DEVICE constexpr auto composition(ComposedLayout const& a, - LayoutOrTile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); } @@ -433,92 +432,101 @@ zip(ComposedLayout const& a) // Partitions -template +template CUTE_HOST_DEVICE constexpr auto logical_divide(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto tile_unzip(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto tiled_divide(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto zipped_divide(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto flat_divide(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto logical_product(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); } -template +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); +} + +template CUTE_HOST_DEVICE constexpr auto tiled_product(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); } -template +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); +} + +template CUTE_HOST_DEVICE constexpr auto blocked_product(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto raked_product(ComposedLayout const& a, - Tile const& b) + Tiler const& b) { return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); } @@ -585,16 +593,19 @@ CUTE_HOST_DEVICE constexpr auto recast_layout(ComposedLayout const& layout) { - if constexpr (sizeof(NewType) == sizeof(OldType)) { + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { return layout; - } else if constexpr (sizeof(NewType) > sizeof(OldType)) { - static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); - return upcast(layout); - } else if constexpr (sizeof(NewType) < sizeof(OldType)) { - static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); - return downcast(layout); } - + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + static_assert(dependent_false, "Recast not supported."); + } CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index bd548a46..77ae6fad 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -413,6 +413,19 @@ conditional_return(TrueType const& t, FalseType const& f) { } } +template +CUTE_HOST_DEVICE constexpr +auto +static_value() +{ + if constexpr (is_std_integral::value) { + return Int{}; + } else { + return Trait::value; + } + CUTE_GCC_UNREACHABLE; +} + // // Display utilities // diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 3d3eb013..bcdbf0b4 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -65,6 +65,11 @@ class R { using type = typename conditional, R>::type; }; +template +struct is_ratio : false_type {}; +template +struct is_ratio> : true_type {}; + template CUTE_HOST_DEVICE constexpr typename R::type @@ -72,6 +77,59 @@ ratio(C, C) { return {}; } +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, R) { + return {}; +} + +// +// Non-reduced ratio implementations +// + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, R) { + return {}; +} + template CUTE_HOST_DEVICE constexpr typename R::type @@ -93,6 +151,13 @@ operator*(C, R) { return {}; } +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator/(C, R) { + return {}; +} + // Product with dynamic type needs to produce an integer... template ::value)> @@ -160,6 +225,23 @@ abs(R) { return {}; } +template +CUTE_HOST_DEVICE constexpr +auto +log_2(R) { + static_assert(R::num > 0); + static_assert(R::den > 0); + return log_2(static_cast(R::num)) - log_2(static_cast(R::den)); +} + + +template +CUTE_HOST_DEVICE constexpr +auto +trait_ratio(Trait0, Trait1) { + return nratio(static_value(), static_value()); +} + // // Display utilities // diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index f847594f..2674a767 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -310,4 +310,17 @@ safe_div(T const& t, U const& u) { return t / u; } +/** + * log2 computation + */ + +template +CUTE_HOST_DEVICE constexpr +auto +log_2(T x) { + assert(x > 0); + static_assert(is_unsigned::value, "Only to be used for unsigned integral types."); + return bit_width(x) - 1; +} + } // namespace cute diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 20eb79e4..f0c4e792 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -41,6 +41,7 @@ #include #include +#include namespace cute { diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index ce951b7b..8b84eba9 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -227,7 +227,7 @@ raw_pointer_cast(counting_iterator const& x) { template CUTE_HOST_DEVICE void print(T const* const ptr) { - printf("ptr[%db](%p)", int(sizeof_bits::value), ptr); + printf("ptr["); print(sizeof_bits::value); printf("b](%p)", ptr); } template diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index d5221339..d82af7ef 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -37,7 +37,8 @@ namespace cute { -/** crd2idx maps a coordinate within to an index +/** crd2idx(c,s,d) maps a coordinate within to an index + * * This is computed as follows: * [coord, shape, and stride are all integers => step forward by stride] * op(c, s, d) => c * d @@ -46,7 +47,6 @@ namespace cute * [coord, shape, and stride are all tuples => consider each mode independently] * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) */ - template CUTE_HOST_DEVICE constexpr auto @@ -115,10 +115,6 @@ crd2idx(Coord const& coord, CUTE_GCC_UNREACHABLE; } -// -// If we know Stride is default [CompactColMajor], then we can take shortcuts -// - namespace detail { template @@ -138,26 +134,31 @@ crd2idx_horner(CTuple const& coord, } // end namespace detail +/** crd2idx(c,s) maps a coordinate within Shape to an index + * via a colexicographical enumeration of coordinates in Shape. + * i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...)) + */ template CUTE_HOST_DEVICE constexpr auto crd2idx(Coord const& coord, Shape const& shape) { - static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks"); - if constexpr (is_tuple::value) { - // Flatten and apply Horner's method + if constexpr (is_integral::value) { // Coord is already an index + return coord; + } else if constexpr (is_integral::value) { + static_assert(dependent_false, "Invalid parameters"); + } else { // Make congruent, flatten, and apply Horner's method + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); auto flat_coord = flatten(coord); - auto flat_shape = flatten(shape); + auto flat_shape = flatten(product_like(shape, coord)); return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); - } else { - return coord; } CUTE_GCC_UNREACHABLE; } -/** idx2crd splits an index to a coordinate within . +/** idx2crd(i,s,d) splits an index into a coordinate within . * * This is computed as follows: * [index, shape, and stride are all integers => determine 1D coord] @@ -170,7 +171,6 @@ crd2idx(Coord const& coord, * NOTE: This only works for compact shape+stride layouts. A more general version would * apply to all surjective layouts */ - template CUTE_HOST_DEVICE constexpr auto @@ -207,15 +207,13 @@ idx2crd(Index const& idx, CUTE_GCC_UNREACHABLE; } -// -// If we know Stride is default [CompactColMajor], then we can take shortcuts -// - -//(idx / 1) % s0 -//(idx / s0) % s1 -//(idx / (s0 * s1)) % s2 -//... - +/** idx2crd(i,s) splits an index into a coordinate within Shape + * via a colexicographical enumeration of coordinates in Shape. + * c0 = (idx / 1) % s0 + * c1 = (idx / s0) % s1 + * c2 = (idx / (s0 * s1)) % s2 + * ... + */ template CUTE_HOST_DEVICE constexpr auto diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 164961b2..9651e8d2 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -434,15 +434,20 @@ CUTE_HOST_DEVICE constexpr auto recast_layout(Swizzle const& swizzle) { - if constexpr (sizeof_bits::value == sizeof_bits::value) { + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { return swizzle; - } else if constexpr (sizeof_bits::value > sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); - return upcast::value/sizeof_bits::value>(swizzle); - } else if constexpr (sizeof_bits::value < sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); - return downcast::value/sizeof_bits::value>(swizzle); } + else if constexpr (scale::num == 1) { + return downcast(swizzle); + } + else if constexpr (scale::den == 1) { + return upcast(swizzle); + } + else { + static_assert(dependent_false, "Recast not supported."); + } + CUTE_GCC_UNREACHABLE; } // @@ -453,7 +458,7 @@ template ,Offset,LayoutB> const& a, - Layout const& b) + Layout const& b) { auto common = max_common_layout(a.layout_b(), b); auto base = Int<(1 << M)>{}; @@ -467,7 +472,7 @@ max_common_layout(ComposedLayout,Offset,LayoutB> const& a, template CUTE_HOST_DEVICE constexpr auto -max_common_layout(Layout const& a, +max_common_layout(Layout const& a, ComposedLayout,Offset,LayoutB> const& b) { return max_common_layout(b, a); @@ -477,7 +482,7 @@ template ,Offset,LayoutB> const& a, - Layout const& b) + Layout const& b) { // This assumes that Offset is in the YZ domain of the Swizzle... return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b)); @@ -486,7 +491,7 @@ max_common_vector(ComposedLayout,Offset,LayoutB> const& a, template CUTE_HOST_DEVICE constexpr auto -max_common_vector(Layout const& a, +max_common_vector(Layout const& a, ComposedLayout,Offset,LayoutB> const& b) { return max_common_vector(b, a); @@ -517,13 +522,13 @@ template CUTE_HOST_DEVICE constexpr auto -logical_product(Layout const& block, - ComposedLayout,Offset,LayoutT> const& tile) +logical_product(Layout const& layout, + ComposedLayout,Offset,LayoutT> const& tiler) { - CUTE_STATIC_ASSERT_V(tile.offset() == Int<0>{}, "Require Swizzle offset == 0."); + CUTE_STATIC_ASSERT_V(tiler.offset() == Int<0>{}, "Require Swizzle offset == 0."); // The new layout -- if swizzle wasn't an issue, this is the result // our goal is to determine a new swizzle for these strides - auto new_layout = logical_product(block, tile.layout_b()); + auto new_layout = logical_product(layout, tiler.layout_b()); // This is accomplished by identifying // S o L :=: S? o L* @@ -536,8 +541,8 @@ logical_product(Layout const& block, auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); - // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] - auto layout_only_zy = composition(swizzle_only_zy, tile.layout_b()); + // Compose with the tiler to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, tiler.layout_b()); // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); // Get the Z bit and the Y bits @@ -545,8 +550,8 @@ logical_product(Layout const& block, auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) - auto new_active_Z = new_layout(Int<0>{}, tile.layout_b()[active_Z]); - auto new_active_Y = new_layout(Int<0>{}, tile.layout_b()[active_Y]); + auto new_active_Z = new_layout(Int<0>{}, tiler.layout_b()[active_Z]); + auto new_active_Y = new_layout(Int<0>{}, tiler.layout_b()[active_Y]); // Use this new swizzle identifier to construxt the new swizzle for new_layout // (this also makes sure it's a "valid" swizzle that Swizzle can represent) diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index cf75e3dd..eb3e24f3 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -127,6 +127,18 @@ print(unsigned long long a) { printf("%llu", a); } +CUTE_HOST_DEVICE +void +print(float a) { + printf("%f", a); +} + +CUTE_HOST_DEVICE +void +print(double a) { + printf("%f", a); +} + template CUTE_HOST_DEVICE void diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index bc0dbbe3..6698cc84 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -236,7 +236,7 @@ struct ClusterBarrier { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "mbarrier.init.shared.b64 [%1], %0; \n" + "mbarrier.init.shared::cta.b64 [%1], %0; \n" "}" : : "r"(arrive_count), "r"(smem_addr)); @@ -256,7 +256,7 @@ struct ClusterBarrier { "{\n\t" ".reg .pred P1; \n\t" "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" "@P1 bra.uni DONE; \n\t" "bra.uni LAB_WAIT; \n\t" "DONE: \n\t" @@ -280,7 +280,7 @@ struct ClusterBarrier { ".reg .pred P1; \n\t" ".reg .pred P2; \n\t" "setp.eq.u32 P2, %3, 1;\n\t" - "@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" "selp.b32 %0, 1, 0, P1; \n\t" "}" : "=r"(waitComplete) @@ -302,7 +302,7 @@ struct ClusterBarrier { asm volatile( "{\n\t" ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" "selp.b32 %0, 1, 0, P1; \n\t" "}" : "=r"(waitComplete) @@ -342,7 +342,7 @@ struct ClusterBarrier { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "mbarrier.arrive.shared.b64 _, [%0];\n\t" + "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" "}" : : "r"(smem_addr)); @@ -357,7 +357,7 @@ struct ClusterBarrier { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "mbarrier.ival.shared.b64 [%0]; \n\t" + "mbarrier.ival.shared::cta.b64 [%0]; \n\t" "}" : : "r"(smem_addr)); @@ -418,7 +418,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t" + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" "}" : : "r"(transaction_bytes), "r"(smem_addr)); @@ -455,7 +455,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "mbarrier.expect_tx.shared.b64 [%1], %0; \n\t" + "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" "}" : : "r"(transaction_bytes), "r"(smem_addr)); @@ -563,7 +563,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( "{\n\t" - "cp.async.mbarrier.arrive.shared.b64 [%0];\n\t" + "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" "}" : : "r"(smem_addr)); diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 28611d51..d4ef4b8e 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -77,7 +77,7 @@ struct ClusterLauncher { constexpr static int MaxClusterSize = 32; // Check for hardware compatibility - static inline __host__ + static inline CUTLASS_HOST Status check_cluster_dims(dim3 grid, dim3 cluster) { if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { @@ -89,7 +89,7 @@ struct ClusterLauncher { } } - static inline __host__ + static inline CUTLASS_HOST Status #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) init(void const* kernel_function) @@ -109,7 +109,7 @@ struct ClusterLauncher { } // This is the method we expect to use going forward - static inline __host__ + static inline CUTLASS_HOST Status launch( dim3 const grid_dims, dim3 const cluster_dims, @@ -217,7 +217,7 @@ struct ClusterLaunchParams { /// kernel_ptr, x, y, z); /// @endcode template -__host__ cutlass::Status +CUTLASS_HOST cutlass::Status launch_kernel_on_cluster(const ClusterLaunchParams& params, void const* kernel_ptr, Args&& ... args) diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index c9960bc3..1933000d 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -81,23 +81,59 @@ struct CudaHostAdapter { void *kernel_handles[kMaximumKernelCount]; int32_t kernel_count = 0; + // + // Methods + // + + /// Ctor CudaHostAdapter() = default; /// Dtor virtual ~CudaHostAdapter() {} - /// Copy Ctor deleted - CudaHostAdapter(const CudaHostAdapter&) = delete; + /// Copy Ctor + inline CudaHostAdapter(const CudaHostAdapter & rhs): + kernel_count(rhs.kernel_count) + { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } - /// Copy Assignment deleted - CudaHostAdapter& operator=(const CudaHostAdapter&) = delete; + /// Copy Assignment + inline CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { - /// Move ctor deleted - CudaHostAdapter(CudaHostAdapter&&) = delete; + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + return *this; + } - /// Move assignment deleted - CudaHostAdapter& operator=(CudaHostAdapter&&) = delete; + /// Move ctor + inline CudaHostAdapter(CudaHostAdapter && rhs): + kernel_count(rhs.kernel_count) + { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + /// Move assignment + inline CudaHostAdapter& operator=(CudaHostAdapter && rhs) { + + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + + return *this; + } /// Ctor inline CudaHostAdapter( @@ -112,13 +148,19 @@ struct CudaHostAdapter { } } + /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) + inline bool empty() const { return !kernel_count; } + + /// Returns kernel_count + inline size_t size() const { return static_cast(kernel_count); } + /// Queries the occupancy of a kernel virtual Status query_occupancy( int32_t *device_sms, int32_t *sm_occupancy, int32_t kernel_index, int32_t thread_count, - int32_t smem_size) = 0; + int32_t smem_size) const = 0; /// Launches a kernel without using Threadblock Clusters. virtual Status launch( @@ -127,7 +169,7 @@ struct CudaHostAdapter { size_t const smem_size, cudaStream_t cuda_stream, void** kernel_params, - int32_t kernel_index) = 0; + int32_t kernel_index) const = 0; /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. virtual Status launch( @@ -137,7 +179,7 @@ struct CudaHostAdapter { size_t const smem_size, cudaStream_t cuda_stream, void** kernel_params, - int32_t kernel_index) = 0; + int32_t kernel_index) const = 0; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 5e0ea623..55978d43 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -57,6 +57,9 @@ #define CUTLASS_DEVICE inline #endif +#define CUTLASS_HOST __host__ +#define CUTLASS_GLOBAL __global__ static + //////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index c019dfec..dc69fea0 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -60,7 +60,7 @@ namespace cutlass { /// Generic CUTLASS kernel template. template -__global__ +CUTLASS_GLOBAL void Kernel(typename Operator::Params params) { // Dynamic shared memory base pointer extern __shared__ int SharedStorageBase[]; @@ -76,7 +76,7 @@ void Kernel(typename Operator::Params params) { /// Generic CUTLASS kernel template. template -__global__ +CUTLASS_GLOBAL void Kernel2(typename Operator::Params params) { // Dynamic shared memory base pointer extern __shared__ int SharedStorageBase[]; @@ -96,7 +96,7 @@ void Kernel2(typename Operator::Params params) { /// Generic CUTLASS kernel template. template -__global__ +CUTLASS_GLOBAL #ifdef __CUDACC__ // Enclosing this in __CUDACC__ suppresses MSVC warnings. __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 8867ab9f..c25e520d 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -58,7 +58,6 @@ struct FusionOperation { static constexpr int AlignmentScalar = 0; static constexpr bool IsScaleFactorSupported = false; static constexpr bool IsPerRowScaleSupported = false; - using ElementBias = void; static constexpr int AlignmentBias = 0; static constexpr bool IsPerRowBiasSupported = false; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 15cc10f4..f429dad8 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -240,8 +240,10 @@ class LinearCombinationBiasElementwise { NumericArrayConverter convert_z; frag_Z = convert_z(result_Z); - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); + if constexpr (kStoreT) { + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } } /// Applies the operation when is_source_needed() is false @@ -269,8 +271,10 @@ class LinearCombinationBiasElementwise { NumericArrayConverter convert_z; frag_Z = convert_z(result_Z); - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); + if constexpr (kStoreT) { + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } } }; diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp index f5da084f..b88938dc 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp @@ -402,7 +402,7 @@ struct OutputTileThreadLayout: DefaultThreadMapTensorOp< CUTLASS_DEVICE static auto tid2coord(int thread_idx) { - return make_layout(ThreadShape{})[thread_idx]; + return cute::idx2crd(thread_idx, ThreadShape{}); } template diff --git a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h index 8cfba768..5780623e 100644 --- a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h +++ b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index be70ac7c..3c3b9c3b 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -44,7 +44,6 @@ namespace cutlass::gemm::collective { using namespace cute; - ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -78,7 +77,8 @@ struct CollectiveMma< GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, - TransformB_> + TransformB_ + > { // // Type Aliases @@ -286,7 +286,6 @@ struct CollectiveMma< copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); } - CUTLASS_PRAGMA_NO_UNROLL for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) { @@ -332,6 +331,7 @@ struct CollectiveMma< }); } + } }; @@ -352,7 +352,8 @@ template < class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_, - class TransformB_> + class TransformB_ +> struct CollectiveMma< MainloopSm80CpAsync, TileShape_, @@ -368,7 +369,8 @@ struct CollectiveMma< GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, - TransformB_> + TransformB_ + > { // // Type Aliases @@ -627,7 +629,6 @@ struct CollectiveMma< copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); } - CUTLASS_PRAGMA_NO_UNROLL for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) { @@ -678,6 +679,7 @@ struct CollectiveMma< }); } + } }; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index daf07c4c..964a7f03 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -353,11 +353,9 @@ struct CollectiveMma< int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -433,12 +431,10 @@ struct CollectiveMma< // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { // This helps avoid early exit of blocks in Cluster. // Waits for all stages to either be released (all // Consumer UNLOCKs), or if the stage was never used diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index dc0a5e9d..8a9e10b2 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -380,15 +380,10 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - - using namespace cute; - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; + TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) @@ -464,14 +459,11 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) - { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -494,9 +486,7 @@ struct CollectiveMma< int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - using namespace cute; + Params const& mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 2b8e92a2..79820b2d 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -680,9 +680,6 @@ struct CollectiveMma< int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { - - using namespace cute; - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); } @@ -696,11 +693,9 @@ struct CollectiveMma< static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); } - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) @@ -812,12 +807,10 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -841,7 +834,6 @@ struct CollectiveMma< int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params) { - using namespace cute; static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 3b0336bf..61d4a4ae 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -111,6 +111,8 @@ struct CollectiveMma< using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename cutlass::PipelineState; + static constexpr int ThreadCount = CUTE_STATIC_V(size(TiledMma{})); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 90552862..3d84483f 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -300,12 +300,9 @@ struct CollectiveMma< int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { - using namespace cute; - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -381,12 +378,10 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -410,8 +405,6 @@ struct CollectiveMma< int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params) { - using namespace cute; - static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 301cb1e0..261f5396 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -297,15 +297,10 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - - using namespace cute; - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; + TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -382,14 +377,11 @@ struct CollectiveMma< CUTLASS_DEVICE void load_tail( MainloopPipeline pipeline, - PipelineState smem_pipe_write) - { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; + PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -412,9 +404,7 @@ struct CollectiveMma< int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - using namespace cute; + Params const& mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); diff --git a/include/cutlass/gemm/device/gemm_sparse_with_visitor.h b/include/cutlass/gemm/device/gemm_sparse_with_visitor.h index 942ba1f5..d17535b6 100644 --- a/include/cutlass/gemm/device/gemm_sparse_with_visitor.h +++ b/include/cutlass/gemm/device/gemm_sparse_with_visitor.h @@ -75,7 +75,7 @@ template < /// Operator class tag typename OperatorClass_ = arch::OpClassSimt, /// Tag indicating architecture to tune for - typename ArchTag_ = arch::Sm70, + typename ArchTag_ = arch::Sm80, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape_ = typename DefaultGemmConfiguration< OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, @@ -243,7 +243,7 @@ class SparseGemmWithVisitor { /// Gets the workspace size static size_t get_workspace_size(Arguments const &args) { - + size_t bytes = 0; return bytes; @@ -271,7 +271,7 @@ class SparseGemmWithVisitor { args.ref_E.non_const_ref(), args.epilogue }; - + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); if (smem_size >= (48 << 10)) { cudaError_t result = cudaFuncSetAttribute(Kernel, @@ -324,9 +324,9 @@ class SparseGemmWithVisitor { Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - + Status status = initialize(args, workspace, stream); - + if (status == Status::kSuccess) { status = run(stream); } diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 94e7fa88..20a49f77 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -339,7 +339,10 @@ class GemmUniversalAdapter< /// Primary run() entry point API that is static allowing users to create and manage their own params. /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() static Status - run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); @@ -425,7 +428,9 @@ class GemmUniversalAdapter< cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr ) { - Status status = initialize(args, workspace, stream); + + Status status = initialize(args, workspace, stream, cuda_adapter); + if (Status::kSuccess == status) { status = run(params_, stream, cuda_adapter); } @@ -444,14 +449,14 @@ class GemmUniversalAdapter< /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status - run(cudaStream_t stream = nullptr) { - return run(params_, stream); + run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return run(params_, stream, cuda_adapter); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status - operator()(cudaStream_t stream = nullptr) { - return run(params_, stream); + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return run(params_, stream, cuda_adapter); } }; diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 0a8f75ee..408fc0cc 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -70,6 +70,8 @@ class GemmUniversalBase { public: using GemmKernel = GemmKernel_; + + /// Boolean indicating whether the CudaHostAdapter is enabled static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; using ThreadblockShape = typename GemmKernel::Mma::Shape; @@ -99,6 +101,14 @@ class GemmUniversalBase { /// Argument structure using Arguments = typename GemmKernel::Arguments; + + /// Index of the GEMM Kernel within the CudaHostAdapter + static int32_t const kGemmKernelIndex = 0; + + /// Kernel dynamic shared memory allocation requirement + /// Update the kernel function's shared memory configuration for the current device + static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage); + protected: // @@ -114,9 +124,7 @@ class GemmUniversalBase { /// Kernel SM occupancy (in thread blocks) CUTLASS_THREAD_LOCAL static int sm_occupancy_; - /// Kernel dynamic shared memory allocation requirement - /// Update the kernel function's shared memory configuration for the current device - static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage); +protected: /// Initialize static thread-local members for the thread's current device, /// if necessary. @@ -148,12 +156,12 @@ class GemmUniversalBase { } // If requires more than 48KB: configure for extended, dynamic shared memory - if constexpr (smem_size_ >= (48 << 10)) + if constexpr (kSharedStorageSize >= (48 << 10)) { cudart_result = cudaFuncSetAttribute( Kernel2, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size_); + kSharedStorageSize); if (cudart_result != cudaSuccess) { CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; @@ -165,7 +173,7 @@ class GemmUniversalBase { &sm_occupancy_, Kernel2, GemmKernel::kThreadCount, - smem_size_, + kSharedStorageSize, cudaOccupancyDisableCachingOverride); if (cudart_result != cudaSuccess) { CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); @@ -179,7 +187,7 @@ class GemmUniversalBase { "device_ordinal: (" << device_ordinal_ << "), " "device_sms: (" << device_sms_ << "), " "sm_occupancy: (" << sm_occupancy_ << ") " - "smem_size: (" << smem_size_ << ") " + "smem_size: (" << kSharedStorageSize << ") " "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); return Status::kSuccess; @@ -197,16 +205,58 @@ class GemmUniversalBase { /// Initialize params member - Status init_params(Arguments const &args) + Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { - // Initialize static device properties, if necessary - Status result = init_device_props(); - if (result != Status::kSuccess) { - return result; + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + // + // Occupancy query using CudaHostAdapter::query_occupancy(). + // + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return status; + } + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + + // Initialize static device properties, if necessary + Status result = init_device_props(); + + if (result != Status::kSuccess) { + return result; + } + + // + // Use thread-local static members for occupancy query initialized by call to + // `init_device_props()` + // + + device_sms = device_sms_; + sm_occupancy = sm_occupancy_; } // Initialize params member - params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_); + params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy); return Status::kSuccess; } @@ -217,11 +267,11 @@ class GemmUniversalBase { //--------------------------------------------------------------------------------------------- /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const &args) + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); - dim3 grid = get_grid_shape(args); + dim3 grid = get_grid_shape(args, cuda_adapter); if (!(grid.y <= std::numeric_limits::max() && grid.z <= std::numeric_limits::max())) @@ -235,13 +285,13 @@ class GemmUniversalBase { /// Returns the workspace size (in bytes) needed for the problem /// geometry expressed by these arguments - static size_t get_workspace_size(Arguments const &args) + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); // Initialize parameters from args GemmUniversalBase base; - if (base.init_params(args) != Status::kSuccess) { + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { return 0; } @@ -254,13 +304,13 @@ class GemmUniversalBase { /// Returns the grid extents in thread blocks to launch - static dim3 get_grid_shape(Arguments const &args) + static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); // Initialize parameters from args GemmUniversalBase base; - if (base.init_params(args) != Status::kSuccess) { + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { return dim3(0,0,0); } @@ -276,17 +326,48 @@ class GemmUniversalBase { /// Returns the maximum number of active thread blocks per multiprocessor - static int maximum_active_blocks() + static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - // Initialize static device properties, if necessary - if (init_device_props() != Status::kSuccess) { - return -1; + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return -1; + } + } + else { + return -1; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + // Initialize static device properties, if necessary + if (init_device_props() != Status::kSuccess) { + return -1; + } + + sm_occupancy = sm_occupancy_; } CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); - return sm_occupancy_; + return sm_occupancy; } @@ -305,7 +386,7 @@ class GemmUniversalBase { << workspace << ", stream: " << (stream ? "non-null" : "null")); // Initialize parameters from args - Status result = init_params(args); + Status result = init_params(args, cuda_adapter); if (result != Status::kSuccess) { return result; } @@ -340,13 +421,13 @@ class GemmUniversalBase { CUTLASS_TRACE_HOST(" " "grid: (" << grid << "), " "block: (" << block << "), " - "SMEM: (" << smem_size_ << ")"); + "SMEM: (" << kSharedStorageSize << ")"); if constexpr (kEnableCudaHostAdapter) { CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { void* kernel_params[] = {¶ms_}; - return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0); + return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0); } else { return Status::kErrorInternal; @@ -355,7 +436,7 @@ class GemmUniversalBase { else { CUTLASS_ASSERT(cuda_adapter == nullptr); - Kernel2<<>>(params_); + Kernel2<<>>(params_); // Query for errors cudaError_t result = cudaGetLastError(); @@ -370,9 +451,9 @@ class GemmUniversalBase { /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - return run(stream); + return run(stream, cuda_adapter); } @@ -383,7 +464,7 @@ class GemmUniversalBase { cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { status = run(stream, cuda_adapter); diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h index a0d602da..226b6bbf 100644 --- a/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h +++ b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h @@ -195,4 +195,3 @@ struct DefaultSparseGemmWithVisitor -__global__ void GemmPipelined( +CUTLASS_GLOBAL void GemmPipelined( cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord grid_tiled_shape, typename Mma::IteratorA::Params params_A, diff --git a/include/cutlass/gemm/kernel/gemv_batched_strided.h b/include/cutlass/gemm/kernel/gemv_batched_strided.h index 11490daf..7ecba84a 100755 --- a/include/cutlass/gemm/kernel/gemv_batched_strided.h +++ b/include/cutlass/gemm/kernel/gemv_batched_strided.h @@ -186,7 +186,7 @@ CUTLASS_DEVICE void GemvBatchedStridedDevice( } template -__global__ void GemvBatchedStrided( +CUTLASS_GLOBAL void GemvBatchedStrided( cutlass::gemm::BatchedGemmCoord problem_size, ElementAlphaBeta alpha, ElementAlphaBeta beta, @@ -205,7 +205,7 @@ __global__ void GemvBatchedStrided( } template -__global__ void GemvBatchedStrided( +CUTLASS_GLOBAL void GemvBatchedStrided( cutlass::gemm::BatchedGemmCoord problem_size, ElementAlphaBeta alpha, typename GemvKernel::IteratorA::TensorRef ref_A, @@ -221,7 +221,7 @@ __global__ void GemvBatchedStrided( } template -__global__ void GemvBatchedStrided( +CUTLASS_GLOBAL void GemvBatchedStrided( cutlass::gemm::BatchedGemmCoord problem_size, typename GemvKernel::IteratorA::TensorRef ref_A, typename GemvKernel::IteratorA::TensorRef::LongIndex lda, diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index e5fe6ec5..aecf3758 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -59,7 +59,6 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -77,13 +76,14 @@ class GemmUniversal< using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; - static_assert(cute::is_void_v or cute::is_same_v, - "SM70 kernel does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< TileScheduler_, ArchTag, TileShape, cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; + static constexpr bool is_valid_tile_scheduler = + cute::is_void_v or cute::is_same_v; +static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler."); // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; @@ -131,6 +131,10 @@ class GemmUniversal< Params to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; + + KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count}; + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + return { args.mode, args.problem_shape, @@ -148,13 +152,16 @@ class GemmUniversal< static int get_workspace_size(Arguments const& args) { - return 0; + int workspace_size = 0; + return workspace_size; } static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - return Status::kSuccess; + cutlass::Status status = Status::kSuccess; + + return status; } static dim3 diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 28ac4a0e..8c81ae5d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -45,7 +45,6 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/trace.h" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -74,7 +73,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 67f23afa..72fb5149 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -40,7 +40,6 @@ #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/trace.h" - #include "cute/tensor.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -82,7 +81,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; @@ -121,7 +119,8 @@ class GemmUniversal< sizeof(typename CollectiveMainloop::SharedStorage), sizeof(typename CollectiveEpilogue::SharedStorage))); - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::ThreadCount; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side arguments diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 582beee9..8db1306a 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -44,7 +44,6 @@ #include "cutlass/trace.h" #include "cute/tensor.hpp" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -71,7 +70,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 25d7711c..c206dc4b 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -44,7 +44,6 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/trace.h" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -71,7 +70,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index a48f218c..b39c44d4 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -45,7 +45,6 @@ #include "cutlass/trace.h" #include "cute/tensor.hpp" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -72,7 +71,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; @@ -521,10 +519,10 @@ class GemmUniversal< shared_storage.tensors.epilogue ); - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index c43b50bc..c0b3a236 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -42,7 +42,6 @@ #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -69,7 +68,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 403b24d1..38e42ecf 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -44,7 +44,6 @@ #include "cutlass/trace.h" #include "cute/tensor.hpp" - /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index c9551ec1..77290d3a 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -29,165 +29,24 @@ * **************************************************************************************************/ #pragma once +#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/gemm_coord.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/arch/cluster_sm90.hpp" namespace cutlass::gemm::kernel::detail { /////////////////////////////////////////////////////////////////////////////// // Persistent Thread Block (TB) scheduler -class PersistentTileSchedulerSm90 { - // - // Data members - // - -private: - uint64_t current_work_linear_idx_; - uint64_t total_grid_size_; +class PersistentTileSchedulerSm90: +public StaticPersistentTileScheduler { + using BaseScheduler = StaticPersistentTileScheduler; public: - struct WorkTileInfo { - int32_t M_idx = 0; - int32_t N_idx = 0; - int32_t L_idx = 0; - bool is_valid_tile = false; - - CUTLASS_HOST_DEVICE - bool - is_valid() const { - return is_valid_tile; - } - - CUTLASS_HOST_DEVICE - static WorkTileInfo - invalid_work_tile() { - return {-1, -1, -1, false}; - } - - CUTLASS_HOST_DEVICE - bool - is_final_split(uint32_t k_tiles_per_output_tile) const { - return true; - } - - CUTLASS_HOST_DEVICE - int32_t - reduction_subtile_idx() const { - return -1; - } - }; - + using StaticPersistentTileScheduler::StaticPersistentTileScheduler; using Params = PersistentTileSchedulerSm90Params; using RasterOrder = typename Params::RasterOrder; using RasterOrderOptions = typename Params::RasterOrderOptions; - struct Arguments { - int max_swizzle_size = 1; - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; - }; - - // Sink scheduler params as a member - Params scheduler_params; - - // - // Methods - // - - template - static Params - to_underlying_arguments( - ProblemShapeMNKL problem_shape_mnkl, - TileShape tile_shape, - ClusterShape cluster_shape, - [[maybe_unused]] KernelHardwareInfo const& hw_info, - Arguments const& arguments, - [[maybe_unused]] void* workspace=nullptr, - [[maybe_unused]] const uint32_t epilogue_subtile = 1) { - - // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic - static_assert(cute::is_static::value); - static_assert(cute::is_static::value); - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - - Params params; - params.initialize( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order - ); - - return params; - } - - CUTLASS_HOST_DEVICE - static bool - can_implement(Arguments const& args) { - return true; - } - - CUTLASS_HOST_DEVICE - PersistentTileSchedulerSm90() { }; - - CUTLASS_DEVICE explicit PersistentTileSchedulerSm90(Params const& params_) : scheduler_params(params_) { - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) - if (params_.raster_order_ == RasterOrder::AlongN) { - current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); - } - else { - current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); - } - - total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_); - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx) const { - if (linear_idx >= scheduler_params.blocks_per_problem_) { - return WorkTileInfo::invalid_work_tile(); - } - - // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices - uint64_t work_idx_l, remainder; - scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx); - - uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder); - - auto [work_idx_m, work_idx_n] = get_work_idx_m_and_n(blk_per_grid_dim, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.divmod_cluster_blk_major_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); - - return {work_idx_m, work_idx_n, static_cast(work_idx_l), true}; - } - - CUTLASS_DEVICE - void - advance_to_next_work(uint32_t advance_count = 1) { - current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); - } + using Arguments = BaseScheduler::Arguments; // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle static CUTLASS_DEVICE @@ -236,111 +95,6 @@ class PersistentTileSchedulerSm90 { } - // Computes the linear index within a batch given M and N tile offsets within the batch. - // This essentially inverts the mapping performed in get_work_idx_m_and_n - static CUTLASS_DEVICE - uint64_t - get_linear_idx_from_m_and_n( - int32_t tile_m, - int32_t tile_n, - FastDivmodU64Pow2 const& divmod_cluster_shape_major, - FastDivmodU64Pow2 const& divmod_cluster_shape_minor, - FastDivmodU64 const& divmod_cluster_blk_major, - int32_t log_swizzle_size, - RasterOrder raster_order) { - - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - - uint64_t minor_work_idx, major_work_idx, cluster_minor_offset; - if (raster_order == RasterOrder::AlongN) { - minor_work_idx = static_cast(tile_m); - major_work_idx = static_cast(tile_n); - cluster_minor_offset = cta_m_in_cluster; - } - else { - major_work_idx = static_cast(tile_m); - minor_work_idx = static_cast(tile_n); - cluster_minor_offset = cta_n_in_cluster; - } - - uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset; - cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset); - divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx); - - uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size; - uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1); - - uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major; - - uint64_t cluster_id = (extra << log_swizzle_size) | offset; - return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset; - } - - // Given the inputs, computes the total number of output blocks this problem will compute over - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) { - auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape))); - auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape))); - - return Params::get_tiled_cta_shape_mnl( - to_gemm_coord(problem_shape_mnkl), - to_gemm_coord(cluster_shape), - cta_m, cta_n - ); - } - - // Given the inputs, computes the physical grid we should launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - ProblemShapeMNKL problem_shape_mnk, - BlockShape cta_shape, - ClusterShape cluster_shape, - KernelHardwareInfo hw_info, - Arguments arguments, - bool truncate_by_problem_size=true) { - - auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{}); - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order, - /* truncate_by_problem_size = */true - ); - } - - // Returns whether the block assigned this work should compute the epilogue for the corresponding - // output tile. For the basic tile scheduler, this is always true. - CUTLASS_HOST_DEVICE - static bool - compute_epilogue(WorkTileInfo const&, Params const&) { - return true; - } - - // Performs the reduction across splits for a given output tile. Since this scheduler does - // not split output tiles, no reduction is needed. - template - CUTLASS_DEVICE - static void - fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} - - // Returns whether the current WorkTileInfo passed in should continue to be used. Since - // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo - // passed in should not be used after having been processed. - CUTLASS_DEVICE - static bool - continue_current_work(WorkTileInfo&) { - return false; - } - // The basic tile scheduler does not require any additional workspace template static int @@ -355,74 +109,6 @@ class PersistentTileSchedulerSm90 { return Status::kSuccess; } - template - CUTLASS_HOST_DEVICE - static int - get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) { - // All work units returned by this scheduler cover the entire K iteration - // space of the output tile assigned to the work unit. - return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); - } - - CUTLASS_HOST_DEVICE - static uint32_t - get_work_k_tile_start(WorkTileInfo const&) { - // All work units returned by this scheduler start from K tile 0 - return 0u; - } - - CUTLASS_DEVICE - static bool - need_separate_reduction(Params const& params) { - return false; - } - - CUTLASS_DEVICE - bool - is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { - return false; - } - - CUTLASS_DEVICE - uint32_t - epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { - return 0; - } - - template - CUTLASS_DEVICE - void - separate_reduction( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - // Shares the accumulator set with peers in the global workspace - template - CUTLASS_DEVICE - static void - share( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - CUTLASS_DEVICE - static bool - valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { - return true; - } - - CUTLASS_DEVICE - static bool - requires_separate_reduction(Params const& params) { - return false; - } }; -} // namespace cutlass::gemm::kernel::detail +} diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index c87f2098..cd01b0f8 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -94,6 +94,7 @@ struct SparseGemm { // // Data members // + typename Epilogue::OutputTileIterator::Params params_C; typename Epilogue::OutputTileIterator::TensorRef ref_C; typename Epilogue::OutputTileIterator::Params params_D; @@ -125,8 +126,8 @@ struct SparseGemm { ref_C(ref_C), params_D(ref_D.layout()), ref_D(ref_D), - output_op(output_op), - semaphore(workspace) { + output_op(output_op) { + semaphore = workspace; } }; diff --git a/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h b/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h index 990a6c36..e34d0f83 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h +++ b/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h @@ -1,3 +1,4 @@ + /*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause diff --git a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp new file mode 100644 index 00000000..9b5fd15f --- /dev/null +++ b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/pipeline/pipeline.hpp" +namespace cutlass::gemm::kernel::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Users are not supposed to use this class directly. +// This is a CRTP base class for the actual tile schedulers. +template +class StaticPersistentTileScheduler { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_; + uint64_t total_grid_size_; + +public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, false}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return true; + } + + CUTLASS_HOST_DEVICE + int32_t + reduction_subtile_idx() const { + return -1; + } + }; + + using Params = PersistentTileSchedulerSm90Params; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; +public: + struct Arguments { + int max_swizzle_size = 1; + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; + }; + + template + static Params + to_underlying_arguments( + ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape, + ClusterShape cluster_shape, + [[maybe_unused]] KernelHardwareInfo const& hw_info, + Arguments const& arguments, + [[maybe_unused]] void* workspace=nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1) { + + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); + + Params params; + params.initialize( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order + ); + + return params; + } + + CUTLASS_HOST_DEVICE + static bool + can_implement(Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + StaticPersistentTileScheduler() { } + + CUTLASS_DEVICE explicit StaticPersistentTileScheduler(Params const& params_) : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + if (params_.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); + } + else { + current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); + } + + total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape cluster_shape) { + return get_current_work(); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx) const { + if (linear_idx >= scheduler_params.blocks_per_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices + uint64_t work_idx_l, remainder; + scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx); + + uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder); + + auto [work_idx_m, work_idx_n] = Subclass::get_work_idx_m_and_n(blk_per_grid_dim, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cluster_blk_major_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + + return {work_idx_m, work_idx_n, static_cast(work_idx_l), true}; + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); + } + + // Computes the linear index within a batch given M and N tile offsets within the batch. + // This essentially inverts the mapping performed in get_work_idx_m_and_n + static CUTLASS_DEVICE + uint64_t + get_linear_idx_from_m_and_n( + int32_t tile_m, + int32_t tile_n, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cluster_blk_major, + int32_t log_swizzle_size, + RasterOrder raster_order) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + + uint64_t minor_work_idx, major_work_idx, cluster_minor_offset; + if (raster_order == RasterOrder::AlongN) { + minor_work_idx = static_cast(tile_m); + major_work_idx = static_cast(tile_n); + cluster_minor_offset = cta_m_in_cluster; + } + else { + major_work_idx = static_cast(tile_m); + minor_work_idx = static_cast(tile_n); + cluster_minor_offset = cta_n_in_cluster; + } + + uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset; + cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset); + divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx); + + uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size; + uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1); + + uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major; + + uint64_t cluster_id = (extra << log_swizzle_size) | offset; + return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset; + } + + // Given the inputs, computes the total number of output blocks over which this problem will compute. + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) { + auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape))); + auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape))); + + return Params::get_tiled_cta_shape_mnl( + to_gemm_coord(problem_shape_mnkl), + to_gemm_coord(cluster_shape), + cta_m, cta_n + ); + } + // Kernel helper function to get next work ID + template + CUTLASS_DEVICE + auto + fetch_next_work( + WorkTileInfo work_tile_info, + WorkIdPipeline& work_id_pipeline, + WorkIdPipelineState work_id_pipe_consumer_state) { + WorkTileInfo new_work_tile_info; + advance_to_next_work(); + new_work_tile_info = get_current_work(); + + // Return true to indicate that the WorkID pipeline state should be advanced + return cute::make_tuple(new_work_tile_info, true); + } + + CUTLASS_DEVICE + static auto + work_tile_to_cta_coord(WorkTileInfo work_tile_info) { + // Get every cta coord in three dimensions of the cluster + auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster(); + return make_coord( + work_tile_info.M_idx + static_cast(cta_m_in_cluster), + work_tile_info.N_idx + static_cast(cta_n_in_cluster), + _, + work_tile_info.L_idx + static_cast(cta_l_in_cluster) + ); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + ProblemShapeMNKL problem_shape_mnk, + BlockShape cta_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size=true) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{}); + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */true + ); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + Params const& params, + ProblemShapeMNKL problem_shape_mnk, + BlockShape cta_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{}); + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); + + Arguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.log_swizzle_size_; + } + args.raster_order = params.raster_order_ == RasterOrder::AlongN ? RasterOrderOptions::AlongN : RasterOrderOptions::AlongM; + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + args.max_swizzle_size, + args.raster_order, + /* truncate_by_problem_size = */true + ); + } + + // Convert CTA-level work tile info to cluster-level tile coord + CUTLASS_DEVICE + cute::Coord + tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const { + // TileScheduler works at CTA-level, kernel works at cluster-level + int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_, + scheduler_params.problem_tiles_m_); + int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_, + scheduler_params.problem_tiles_n_); + int l_coord = idx2crd(work_tile_info.L_idx, + scheduler_params.problem_tiles_l_); + return make_coord(m_coord, n_coord, _, l_coord); + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&) { + return true; + } + + // Performs the reduction across splits for a given output tile. Since this scheduler does + // not split output tiles, no reduction is needed. + template + CUTLASS_DEVICE + static void + fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + // Performs the reduction across splits for a given output tile. No fixup is required for + // work units returned by this scheduler. + template + CUTLASS_DEVICE + void + fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) const { } + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed. + CUTLASS_DEVICE + static bool + continue_current_work(WorkTileInfo&) { + return false; + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + CUTLASS_DEVICE + static bool + need_separate_reduction(Params const& params) { + return false; + } + + CUTLASS_DEVICE + bool + is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { + return false; + } + + template + CUTLASS_DEVICE + void + separate_reduction( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + // Shares the accumulator set with peers in the global workspace + template + CUTLASS_DEVICE + static void + share( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; + } +public: + // Sink scheduler params as a member + Params scheduler_params; +}; + +} // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index be1251ca..85987637 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -87,6 +87,12 @@ struct PersistentTileSchedulerSm90Params { int32_t log_swizzle_size_ = 0; RasterOrder raster_order_ = RasterOrder::AlongN; + uint32_t problem_tiles_m_ = 0; + uint32_t problem_tiles_n_ = 0; + uint32_t problem_tiles_l_ = 0; + uint32_t cluster_shape_m_ = 0; + uint32_t cluster_shape_n_ = 0; + // Initializes members. This variant of the method should only be used when // problem_shape and tile_shape contain modes of only rank 1. void @@ -127,6 +133,12 @@ struct PersistentTileSchedulerSm90Params { auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + problem_tiles_m_ = problem_blocks_m / cluster_shape.m(); + problem_tiles_n_ = problem_blocks_n / cluster_shape.n(); + problem_tiles_l_ = problem_blocks.z; + cluster_shape_m_ = cluster_shape.m(); + cluster_shape_n_ = cluster_shape.n(); + RasterOrder raster_order = get_rasterization_order( problem_blocks_m, problem_blocks_n, diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index aa2806db..32460b62 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file \brief This defines a "fragment" iterator for visiting the fragments of a warp tile that participate in one warp-level mma operation. diff --git a/media/docs/cute/00_quickstart.md b/media/docs/cute/00_quickstart.md index a9c35f1b..47f9c561 100644 --- a/media/docs/cute/00_quickstart.md +++ b/media/docs/cute/00_quickstart.md @@ -6,12 +6,12 @@ The core abstraction of CuTe are the hierarchically multidimensional layouts whi ## System Requirements -CuTe shares CUTLASS 3.0's software requirements, +CuTe shares CUTLASS 3.x's software requirements, including NVCC with a C++17 host compiler. ## Knowledge prerequisites -CuTe is a CUDA C++ library. It requires C++17 +CuTe is a CUDA C++ header-only library. It requires C++17 (the revision of the C++ Standard that was released in 2017). Throughout this tutorial, we assume intermediate C++ experience. @@ -29,8 +29,10 @@ and how to launch kernels. ## Building Tests and Examples CuTe's tests and examples build and run as part of CUTLASS's normal build process. + CuTe's unit tests live in the [`test/unit/cute`](../../../test/unit/cute) subdirectory. -Its examples live in the [`examples/cute`](../../../examples/cute) subdirectory. + +CuTe's examples live in the [`examples/cute`](../../../examples/cute) subdirectory. ## Library Organization @@ -38,9 +40,9 @@ CuTe is a header-only C++ library, so there is no source code that needs buildin | Directory | Contents | |------------------------|------------------------| -| [`include/cute`](../../../include/cute) | Each header in the top level corresponds to one of the fundamental building blocks of CuTe, such as [`Layout`](../../../include/cute/layout.hpp) or [`Tensor`](../../../include/cute/tensor.hpp). | -| [`include/cute/container`](../../../include/cute/container) | Implementations of STL-like container objects, such as tuple, array, aligned array, and array views. | -| [`include/cute/numeric`](../../../include/cute/numeric) | Templates that handle nonstandard floating-point types, unsigned integers, complex numbers, and integer sequence - like fundamental numeric data types. | +| [`include/cute`](../../../include/cute) | Each header in the top level corresponds to one of the fundamental building blocks of CuTe, such as [`Layout`](../../../include/cute/layout.hpp) and [`Tensor`](../../../include/cute/tensor.hpp). | +| [`include/cute/container`](../../../include/cute/container) | Implementations of STL-like objects, such as tuple, array, and aligned array. | +| [`include/cute/numeric`](../../../include/cute/numeric) | Fundamental numeric data types that include nonstandard floating-point types, nonstandard integer types, complex numbers, and integer sequence. | | [`include/cute/algorithm`](../../../include/cute/algorithm) | Implementations of utility algorithms such as copy, fill, and clear that automatically leverage architecture-specific features if available. | | [`include/cute/arch`](../../../include/cute/arch) | Wrappers for architecture-specific matrix-matrix multiply and copy instructions. | | [`include/cute/atom`](../../../include/cute/atom) | Meta-information for instructions in `arch` and utilities like partitioning and tiling. @@ -57,7 +59,7 @@ Other files in this directory discuss specific parts of CuTe. * [`01_layout.md`](./01_layout.md) describes `Layout`, CuTe's core abstraction. -* [`02_layout_operations.md`](./02_layout_operations.md) describes more advanced `Layout` operations and the CuTe layout algebra. +* [`02_layout_algebra.md`](./02_layout_algebra.md) describes more advanced `Layout` operations and the CuTe layout algebra. * [`03_tensor.md`](./03_tensor.md) describes `Tensor`, a multidimensional array abstraction which composes `Layout` @@ -74,5 +76,44 @@ Other files in this directory discuss specific parts of CuTe. * [`0y_predication.md`](./0y_predication.md) explains what to do if a tiling doesn't fit evenly into a matrix. -* [`0z_tma_tensors.md`](./0z_tma_tensors.md) summarizes - how CuTe supports TMA loads and stores. +* [`0z_tma_tensors.md`](./0z_tma_tensors.md) explains an advanced `Tensor` type that CuTe uses to support TMA loads and stores. + +## Quick Tips + +### How do I print CuTe objects on host or device? + +The `cute::print` function has overloads for almost all CuTe types, including Pointers, Integers, Strides, Shapes, Layouts, and Tensors. When in doubt, try calling `print` on it. + +CuTe's print functions work on either host or device. +Note that on device, printing is expensive. +Even just leaving print code in place on device, +even if it is never called +(e.g., printing in an `if` branch that is not taken at run time), +may generate slower code. +Thus, be sure to remove code that prints on device after debugging. + +You might also only want to print on thread 0 of each threadblock, or threadblock 0 of the grid. The `thread0()` function returns true only for global thread 0 of the kernel, that is, for thread 0 of threadblock 0. A common idiom for printing CuTe objects to print only on global thread 0. + +```c++ +if (thread0()) { + print(some_cute_object); +} +``` + +Some algorithms depend on some thread or threadblock, +so you may need to print on threads or threadblocks other than zero. +The header file +[`cute/util/debug.hpp`](../../../include/cute/util/debug.hpp), +among other utilities, +includes the function `bool thread(int tid, int bid)` +that returns `true` if running on thread `tid` and threadblock `bid`. + +#### Other output formats + +Some CuTe types have special printing functions that use a different output format. + +The `cute::print_layout` function will display any rank-2 layout in a plain test table. This is excellent for visualizing the map from coordinates to indices. + +The `cute::print_tensor` function will display any rank-1, rank-2, rank-3, or rank-4 tensor in a plain text multidimensional table. The values of the tensor are printed so you can verify the tile of data is what you expect after a copy, for example. + +The `cute::print_latex` function will print LaTeX commands that you can use to build a nicely formatted and colored tables via `pdflatex`. This work for `Layout`, `TiledCopy`, and `TiledMMA`, which can be very useful to get a sense of layout patterns and partitioning patterns within CuTe. diff --git a/media/docs/cute/01_layout.md b/media/docs/cute/01_layout.md index c1a25ac1..90530128 100644 --- a/media/docs/cute/01_layout.md +++ b/media/docs/cute/01_layout.md @@ -1,145 +1,331 @@ # CuTe Layouts -## Layout - This document describes `Layout`, CuTe's core abstraction. -A `Layout` maps from a logical coordinate space +Fundamentally, a `Layout` maps from coordinate space(s) to an index space. `Layout`s present a common interface to multidimensional array access that abstracts away the details of how the array's elements are organized in memory. This lets users write algorithms that access multidimensional arrays generically, -so that layouts can change, without users' code needing to change. +so that layouts can change, without users' code needing to change. For example, a row-major MxN layout and a column-major MxN layout can be treated identically in software. CuTe also provides an "algebra of `Layout`s." `Layout`s can be combined and manipulated to construct more complicated layouts -and to partition them across other layouts. +and to tile layouts across other layouts. This can help users do things like partition layouts of data over layouts of threads. -## Layouts and Tensors +## Fundamental Types and Concepts -Any of the `Layout`s discussed in this section can be composed with data -- e.g., a pointer or an array -- to create a `Tensor`. -The `Layout`'s logical coordinate space represents the logical "shape" of the data, -e.g., the modes of the `Tensor` and their extents. -The `Layout` maps a logical coordinate into an index, -which is an offset to be used to index into the array of data. +### Integers -For details on `Tensor`, please refer to the -[`Tensor` section of the tutorial](./03_tensor.md). +CuTe makes great use of dynamic (known only at run-time) and static (known at compile-time) integers. + +* Dynamic integers (or "run-time integers") are just ordinary integral types like `int` or `size_t` or `uint16_t`. Anything that is accepted by `std::is_integral` is considered a dynamic integer in CuTe. + +* Static integers (or "compile-time integers") are instantiations of types like `std::integral_constant`. These types encode the value as a `static constexpr` member. They also support casting to their underlying dynamic types, so they can be used in expressions with dynamic integers. CuTe defines its own CUDA-compatibe static integer types `cute::C` along with overloaded math operators so that math on static integers results in static integers. CuTe defines shortcut aliases `Int<1>`, `Int<2>`, `Int<3>` and `_1`, `_2`, `_3` as conveniences, which you should see often within examples. + +CuTe attempts to handle static and dynamic integers identically. In the examples that follow, all dynamic integers could be replaced with static integers and vice versa. When we say "integer" in CuTe, we almost always mean a static OR dynamic integer. -## Shapes and Strides +CuTe provides a number of traits to work with integers. +* `cute::is_integral`: Checks whether `T` is a static or dynamic integer type. +* `cute::is_std_integral`: Checks whether `T` is a dynamic integer type. Equivalent to `std::is_integral`. +* `cute::is_static`: Checks whether `T` is an empty type (so instantiations cannot depend on any dynamic information). Equivalent to `std::is_empty`. +* `cute::is_constant`: Checks that `T` is a static integer AND its value is equivalent to `N`. -A `Layout` is a pair of `Shape` and `Stride`. -Both `Shape` and `Stride` are `IntTuple` types. +See the [`integral_constant` implementations](../../../include/cute/numeric/integral_constant.hpp) for more information. + +### Tuple + +A tuple is a finite ordered list of zero or more elements. +The [`cute::tuple` class](../../../include/cute/container/tuple.hpp) behaves like `std::tuple`, but works on device and host. It imposes restrictions on its template arguments and strips down the implementation for performance and simplicity. ### IntTuple -An `IntTuple` is defined recursively as either a single integer, or a tuple of `IntTuple`s. -This means that `IntTuple`s can be arbitrarily nested. -Operations defined on `IntTuple`s include the following. +CuTe defines the IntTuple concept as either an integer, or a tuple of IntTuples. Note the recursive definition. +In C++, we define [operations on `IntTuple`](../../../include/cute/int_tuple.hpp). + +Examples of `IntTuple`s include: +* `int{2}`, the dynamic integer 2. +* `Int<3>{}`, the static integer 3. +* `make_tuple(int{2}, Int<3>{})`, the tuple of dynamic-2, and static-3. +* `make_tuple(uint16_t{42}, make_tuple(Int<1>{}, int32_t{3}), Int<17>{})`, the tuple of dynamic-42, tuple of static-1 and dynamic-3, and static-17. -* `get(IntTuple)`: The `I`th element of the `IntTuple`. For an `IntTuple` consisting of a single integer, `get<0>` is just that integer. +CuTe reuses the `IntTuple` concept for many different things, +including Shape, Stride, Step, and Coord +(see [`include/cute/layout.hpp`](../../../include/cute/layout.hpp)). + +Operations defined on `IntTuple`s include the following. * `rank(IntTuple)`: The number of elements in an `IntTuple`. A single integer has rank 1, and a tuple has rank `tuple_size`. +* `get(IntTuple)`: The `I`th element of the `IntTuple`, with `I < rank`. For single integers, `get<0>` is just that integer. + * `depth(IntTuple)`: The number of hierarchical `IntTuple`s. A single integer has depth 0, a tuple of integers has depth 1, a tuple that contains a tuple of integers has depth 2, etc. * `size(IntTuple)`: The product of all elements of the `IntTuple`. -We write `IntTuple`s with parenthesis to denote the hierarchy. For example, `6`, `(2)`, `(4,3)`, `(3,(6,2),8)` are all `IntTuple`s. +We write `IntTuple`s with parentheses to denote the hierarchy. For example, `6`, `(2)`, `(4,3)`, and `(3,(6,2),8)` are all `IntTuple`s. + +### Shapes and Strides + +Both `Shape` and `Stride` are `IntTuple` concepts. + +### Layout + +A `Layout` is a tuple of (`Shape`, `Stride`). +Semantically, it implements a mapping from +any coordinate within the Shape to an index via the Stride. + +### Tensor + +A `Layout` can be composed with data -- e.g., a pointer or an array -- to create a `Tensor`. The index generated by the `Layout` is used to subscript an iterator to retrieve the appropriate data. For details on `Tensor`, please refer to the +[`Tensor` section of the tutorial](./03_tensor.md). -## Layout +## Layout Creation and Use -A `Layout` is then a pair of `IntTuple`s. The first element defines the abstract *shape* of the `Layout`, and the second element defines the *strides*, which map from coordinates within the shape to the index space. +A `Layout` is a pair of `IntTuple`s: the `Shape` and the `Stride`. The first element defines the abstract *shape* of the `Layout`, and the second element defines the *strides*, which map from coordinates within the shape to the index space. -Since a `Layout` is just a pair of `IntTuple`s, we can define operations on `Layout`s analogous to those defined on `IntTuple`. +We define many operations on `Layout`s analogous to those defined on `IntTuple`. -* `get(Layout)`: The `I`th sub-layout of the `Layout`. +* `rank(Layout)`: The number of modes in a `Layout`. Equivalent to the tuple size of the `Layout`'s shape. -* `rank(Layout)`: The number of modes in a `Layout`. +* `get(Layout)`: The `I`th sub-layout of the `Layout`, with `I < rank`. -* `depth(Layout)`: The number of hierarchical `Layout`s. A single integer has depth 0, a tuple of integers has depth 1, a tuple that contains a tuple of integers has depth 2, etc. +* `depth(Layout)`: The depth of the `Layout`'s shape. A single integer has depth 0, a tuple of integers has depth 1, a tuple of tuples of integers has depth 2, etc. * `shape(Layout)`: The shape of the `Layout`. * `stride(Layout)`: The stride of the `Layout`. -* `size(Layout)`: The logical extent of the `Layout`. Equivalent to `size(shape(Layout))`. +* `size(Layout)`: The size of the `Layout` function's domain. Equivalent to `size(shape(Layout))`. + +* `cosize(Layout)`: The size of the `Layout` function's codomain (not necessarily the range). Equivalent to `A(size(A) - 1) + 1`. ### Hierarchical access functions -`IntTuple`s and thus `Layout`s can be arbitrarily nested. +`IntTuple`s and `Layout`s can be arbitrarily nested. For convenience, we define versions of some of the above functions that take a sequence of integers, instead of just one integer. This makes it possible to access elements -inside of nested `IntTuple` or `Layout`. -For example, we permit `get(x)`, where `I...` here -and throughout this section is a "C++ parameter pack" -that denotes zero or more (integer) template arguments. -That is, `get(x)` is equivalent to -`get(` $\dots$ `(get(get(x)))` $\dots$ `))`, -where the ellipses are pseudocode and not actual C++ syntax. -These hierarchical access functions include the following. +inside of nested `IntTuple` or `Layout` more easily. +For example, we permit `get(x)`, where `I...` is a "C++ parameter pack" that denotes zero or more (integer) template arguments. These hierarchical access functions include the following. + +* `get(x) := get(...(get(get(x)))...)`. Extract the `IN`th of the ... of the `I1`st of the `I0`th element of `x`. * `rank(x) := rank(get(x))`. The rank of the `I...`th element of `x`. * `depth(x) := depth(get(x))`. The depth of the `I...`th element of `x`. +* `shape(x) := shape(get(x))`. The shape of the `I...`th element of `x`. + * `size(x) := size(get(x))`. The size of the `I...`th element of `x`. -### Vector examples +In the following examples, you'll see use of `size<0>` and `size<1>` to determine loops bounds for the 0th and 1st mode of a layout or tensor. + +### Constructing a Layout + +A `Layout` can be constructed in many different ways. +It can include any combination of compile-time (static) integers +or run-time (dynamic) integers. + +```c++ +Layout s8 = make_layout(Int<8>{}); +Layout d8 = make_layout(8); + +Layout s2xs4 = make_layout(make_shape(Int<2>{},Int<4>{})); +Layout s2xd4 = make_layout(make_shape(Int<2>{},4)); + +Layout s2xd4_a = make_layout(make_shape (Int< 2>{},4), + make_stride(Int<12>{},Int<1>{})); +Layout s2xd4_col = make_layout(make_shape(Int<2>{},4), + LayoutLeft{}); +Layout s2xd4_row = make_layout(make_shape(Int<2>{},4), + LayoutRight{}); + +Layout s2xh4 = make_layout(make_shape (2,make_shape (2,2)), + make_stride(4,make_stride(2,1))); +Layout s2xh4_col = make_layout(shape(s2xh4), + LayoutLeft{}); +``` + +The `make_layout` function returns a `Layout`. +It deduces the types of the function's arguments and returns a `Layout` with the appropriate template arguments. +Similarly, the `make_shape` and `make_stride` functions +return a `Shape` resp. `Stride`. +CuTe often uses these `make_*` functions +due to restrictions around constructor template argument deduction (CTAD) and to avoid having to repeat static or dynamic integer types. + +When the `Stride` argument is omitted, it is generated from the provided `Shape` with `LayoutLeft` as default. The `LayoutLeft` tag constructs strides as an exclusive prefix product of the `Shape` from left to right, without regard to the `Shape`'s hierarchy. This can be considered a "generalized column-major stride generation". The `LayoutRight` tag constructs strides as an exclusive prefix product of the `Shape` from right to left, without regard to the `Shape`'s hierarchy. For shapes of depth one, this can be considered a "row-major stride generation", but for hierarchical shapes the resulting strides may be surprising. For example, the strides of `s2xh4` above could be generated with `LayoutRight`. + +Calling `print` on each layout above results in the following + +``` +s8 : _8:_1 +d8 : 8:_1 +s2xs4 : (_2,_4):(_1,_2) +s2xd4 : (_2,4):(_1,_2) +s2xd4_a : (_2,4):(_12,_1) +s2xd4_col : (_2,4):(_1,_2) +s2xd4_row : (_2,4):(4,_1) +s2xh4 : (2,(2,2)):(4,(2,1)) +s2xh4_col : (2,(2,2)):(_1,(2,4)) +``` + +The `Shape:Stride` notation is used quite often for `Layout`. The `_N` notation is shorthand for a static integer while other integers are dynamic integers. Observe that both `Shape` and `Stride` may be composed of both static and dynamic integers. + +Also note that the `Shape` and `Stride` are assumed to be *congruent*. That is, `Shape` and `Stride` have the same tuple profiles. For every integer in `Shape`, there is a corresponding integer in `Stride`. This can be asserted with +```cpp +static_assert(congruent(my_shape, my_stride)); +``` + +### Using a Layout + +The fundamental use of a `Layout` is to map between coordinate space(s) defined by the `Shape` and an index space defined by the `Stride`. For example, to print an arbitrary rank-2 layout in a 2-D table, we can write the function + +```c++ +template +void print2D(Layout const& layout) +{ + for (int m = 0; m < size<0>(layout); ++m) { + for (int n = 0; n < size<1>(layout); ++n) { + printf("%3d ", layout(m,n)); + } + printf("\n"); + } +} +``` + +which produces the following output for the above examples. + +``` +> print2D(s2xs4) + 0 2 4 6 + 1 3 5 7 +> print2D(s2xd4_a) + 0 1 2 3 + 12 13 14 15 +> print2D(s2xh4_col) + 0 2 4 6 + 1 3 5 7 +> print2D(s2xh4) + 0 2 1 3 + 4 6 5 7 +``` + +We can see static, dynamic, row-major, column-major, and hierarchical layouts printed here. The statement `layout(m,n)` provides the mapping of +the logical 2-D coordinate (m,n) to the 1-D index. + +Interestingly, the `s2xh4` example isn't row-major or column-major. Furthermore, it has three modes but is still interpreted as rank-2 and we're using a 2-D coordinate. Specifically, `s2xh4` has a 2-D multi-mode in the second mode, but we're still able to use a 1-D coordinate for that mode. More on this in the next section, but first we can generalize this another step. Let's use a 1-D coordinate and treat all of the modes of each layout as a single multi-mode. For instance, the following `print1D` function + +```c++ +template +void print1D(Layout const& layout) +{ + for (int i = 0; i < size(layout); ++i) { + printf("%3d ", layout(i)); + } +} +``` + +produces the following output for the above examples. + +``` +> print1D(s2xs4) + 0 1 2 3 4 5 6 7 +> print1D(s2xd4_a) + 0 12 1 13 2 14 3 15 +> print1D(s2xh4_col) + 0 1 2 3 4 5 6 7 +> print1D(s2xh4) + 0 4 2 6 1 5 3 7 +``` + +Any multi-mode of a layout, including the entire layout itself, can accept a 1-D coordinate. More on this in the following sections. + +CuTe provides more printing utilities for visualizing Layouts. The `print_layout` function produces a formatted 2-D table of the Layout's mapping. + +```text +> print_layout(s2xh4) +(2,(2,2)):(4,(2,1)) + 0 1 2 3 + +---+---+---+---+ + 0 | 0 | 2 | 1 | 3 | + +---+---+---+---+ + 1 | 4 | 6 | 5 | 7 | + +---+---+---+---+ +``` + +The `print_latex` function generates LaTeX that can be compiled with `pdflatex` into a color-coded vector graphics image of the same 2-D table. + +### Vector Layouts + +We define a vector as any `Layout` with `rank == 1`. +For example, the layout `8:1` can be interpreted as an 8-element vector whose indices are contiguous. + +``` +Layout: 8:1 +Coord : 0 1 2 3 4 5 6 7 +Index : 0 1 2 3 4 5 6 7 +``` + +Similarly, +the layout `8:2` can be interpreted as an 8-element vector where the indices of the elements are strided by `2`. -We define a vector as any `Shape` and `Stride` pair with `rank == 1`. -For example, the `Layout` +``` +Layout: 8:2 +Coord : 0 1 2 3 4 5 6 7 +Index : 0 2 4 6 8 10 12 14 +``` + +By the above rank-1 definition, we *also* interpret layout `((4,2)):((2,1))` as a vector, since its shape is rank-1. The inner shape looks like a 4x2 column-major matrix, but the extra pair of parenthesis suggest we can interpret those two modes as a 1-D 8-element vector. The strides tell us that the first `4` elements are strided by `2` and then there are `2` of those first elements strided by `1`. ``` -Shape: (8) -Stride: (1) +Layout: ((4,2)):((2,1)) +Coord : 0 1 2 3 4 5 6 7 +Index : 0 2 4 8 1 3 5 7 ``` -defines a contiguous 8-element vector. -For a vector with the same Shape but a Stride of `(2)`, -the interpretation is that the eight elements -are stored at positions 0, 2, 4, $\dots$, 14. +We can see the second set of `4` elements are duplicates of the first `4` with an extra stride of `1`. -By the above definition, we *also* interpret +Consider the layout `((4,2)):((1,4))`. Again, it's `4` elements strided by `1` and then `2` of those first elements strided by `4`. ``` -Shape: ((4,2)) -Stride: ((1,4)) +Layout: ((4,2)):((1,4)) +Coord : 0 1 2 3 4 5 6 7 +Index : 0 1 2 3 4 5 6 7 ``` -as a vector, since its shape is rank 1. The inner shape describes a 4x2 layout of data in column-major order, but the extra pair of parenthesis suggest we can interpret those two modes as a single 1-D 8-element vector instead. Due to the strides, the elements are also contiguous. +As a function from integers to integers, it's identical to `8:1`. It's the identity function. ### Matrix examples -Generalizing, we define a matrix as any `Shape` and `Stride` pair with rank 2. For example, +Generalizing, we define a matrix as any `Layout` that is rank-2. For example, ``` -Shape: (4,2) -Stride: (1,4) +Shape : (4,2) +Stride: (1,4) 0 4 1 5 2 6 3 7 ``` -is a 4x2 column-major matrix, and +is a 4x2 column-major layout with stride-1 down the columns and stride-4 across the rows, and ``` -Shape: (4,2) -Stride: (2,1) +Shape : (4,2) +Stride: (2,1) 0 1 2 3 4 5 6 7 ``` -is a 4x2 row-major matrix. +is a 4x2 row-major layout with stride-2 down the columns and stride-1 across the rows. Majorness is simply which mode has stride-1. -Each of the modes of the matrix can also be split into *multi-indices* like the vector example. -This lets us express more layouts beyond just row major and column major. For example, +Just like the vector layouts, each of the modes of the matrix can also be split into *multi-modes*. +This lets us express more layouts beyond just row-major and column-major. For example, ``` Shape: ((2,2),2) @@ -150,117 +336,200 @@ Stride: ((4,1),2) 5 7 ``` -is also logically 4x2, with a stride of 2 across the rows but a multi-stride down the columns. -Since this layout is logically 4x2, +is also logically 4x2, with stride-2 across the rows but a multi-stride down the columns. The first `2` elements down the column have a stride of `4` and then there is a copy of those with stride-1. Since this layout is logically 4x2, like the column-major and row-major examples above, we can _still_ use 2-D coordinates to index into it. -## Constructing a `Layout` +## Layout Concepts -A `Layout` can be constructed in many different ways. -It can include any combination of compile-time (static) integers -or run-time (dynamic) integers. +In this section, we'll introduce the coordinate sets that `Layout`s accept and how the coordinate mappings and index mappings are computed. -```c++ -auto layout_8s = make_layout(Int<8>{}); -auto layout_8d = make_layout(8); +### Layout compatibility -auto layout_2sx4s = make_layout(make_shape(Int<2>{},Int<4>{})); -auto layout_2sx4d = make_layout(make_shape(Int<2>{},4)); +We say that layout A is *compatible* with layout B if the shape of A is compatible with the shape of B. +Shape A is compatible with shape B if -auto layout_2x4 = make_layout(make_shape (2, make_shape (2,2)), - make_stride(4, make_stride(2,1))); -``` +* the size of A is equal to the size of B and +* all coordinates within A are valid coordinates within B. -The `make_layout` function returns a `Layout`. -It deduces the returned `Layout`'s template arguments from the function's arguments. -Similarly, the `make_shape` and `make_stride` functions -return a `Shape` resp. `Stride`. -CuTe often uses these `make_*` functions, -because constructor template argument deduction (CTAD) -does not work for `cute::tuple` as it works for `std::tuple`. +For example: +* Shape 24 is NOT compatible with Shape 32. +* Shape 24 is compatible with Shape (4,6). +* Shape (4,6) is compatible with Shape ((2,2),6). +* Shape ((2,2),6) is compatible with Shape ((2,2),(3,2)). +* Shape 24 is compatible with Shape ((2,2),(3,2)). +* Shape 24 is compatible with Shape ((2,3),4). +* Shape ((2,3),4) is NOT compatible with Shape ((2,2),(3,2)). +* Shape ((2,2),(3,2)) is NOT compatible with Shape ((2,3),4). +* Shape 24 is compatible with Shape (24). +* Shape (24) is NOT compatible with Shape 24. +* Shape (24) is NOT compatible with Shape (4,6). -## Using a `Layout` +That is, *compatible* is a weak partial order on Shapes as it is reflexive, antisymmetric, and transitive. -The fundamental use of a `Layout` is to map between logical coordinate space(s) and an index space. For example, to print an arbitrary rank-2 layout, we can write the function +### Layouts Coordinates -```c++ -template -void print2D(Layout const& layout) -{ - for (int m = 0; m < size<0>(layout); ++m) { - for (int n = 0; n < size<1>(layout); ++n) { - printf("%3d ", layout(m,n)); - } - printf("\n"); - } -} +With the notion of compatibility above, we emphasize that every `Layout` accepts multiple kinds of coordinates. Every `Layout` accepts coordinates for any `Shape` that is compatible with it. CuTe provides mappings between these sets of coordinates via a colexicographical order. + +Thus, all Layouts provide two fundamental mappings: + +* the map from an input coordinate to the corresponding natural coordinate via the `Shape`, +* and the map from a natural coordinate to the index via the `Stride`. + +#### Coordinate Mapping + +The map from an input coordinate to a natural coordinate is the application of a colexicographical order (reading right to left, instead of "lexicographical," which reads left to right) within the `Shape`. + +Take the shape `(3,(2,3))`, for example. This shape has three coordinate sets: the 1-D coordinates, the 2-D coordinates, and the natural (h-D) coordinates. + +| 1-D | 2-D | Natural | | 1-D | 2-D | Natural | +| ----- | ------- | ----------- |-| ----- | ------- | ----------- | +| `0` | `(0,0)` | `(0,(0,0))` | | `9` | `(0,3)` | `(0,(1,1))` | +| `1` | `(1,0)` | `(1,(0,0))` | | `10` | `(1,3)` | `(1,(1,1))` | +| `2` | `(2,0)` | `(2,(0,0))` | | `11` | `(2,3)` | `(2,(1,1))` | +| `3` | `(0,1)` | `(0,(1,0))` | | `12` | `(0,4)` | `(0,(0,2))` | +| `4` | `(1,1)` | `(1,(1,0))` | | `13` | `(1,4)` | `(1,(0,2))` | +| `5` | `(2,1)` | `(2,(1,0))` | | `14` | `(2,4)` | `(2,(0,2))` | +| `6` | `(0,2)` | `(0,(0,1))` | | `15` | `(0,5)` | `(0,(1,2))` | +| `7` | `(1,2)` | `(1,(0,1))` | | `16` | `(1,5)` | `(1,(1,2))` | +| `8` | `(2,2)` | `(2,(0,1))` | | `17` | `(2,5)` | `(2,(1,2))` | + +Each coordinate into the shape `(3,(2,3))` has two *equivalent* coordinates and all equivalent coordinates map to the same natural coordinate. To emphasize again, because all of the above coordinates are valid inputs, a Layout with Shape `(3,(2,3))` can be used as if it is a 1-D array of 18 elements by using the 1-D coordinates, a 2-D matrix of 3x6 elements by using the 2-D coordinates, or a h-D tensor of 3x(2x3) elements by using the h-D (natural) coordinates. + +The previous 1-D print demonstrates how CuTe identifies 1-D coordinates with a colexicographical ordering of 2-D coordinates. Iterating from `i = 0` to `size(layout)` and indexing into our layout with the single integer coordinate `i`, traverses the 2-D coordinates in this "generalized-column-major" order, even if the layout maps coordinates to indices in a row-major or more complex fashion. + +The function `cute::idx2crd(idx, shape)` is responsible for the coordinate mapping. It will take any coordinate within the shape and compute the equivalent natural coordinate for that shape. +```cpp +auto shape = Shape<_3,Shape<_2,_3>>{}; +print(idx2crd( 16, shape)); // (1,(1,2)) +print(idx2crd(_16{}, shape)); // (_1,(_1,_2)) +print(idx2crd(make_coord( 1,5), shape)); // (1,(1,2)) +print(idx2crd(make_coord(_1{},5), shape)); // (_1,(1,2)) +print(idx2crd(make_coord( 1,make_coord(1, 2)), shape)); // (1,(1,2)) +print(idx2crd(make_coord(_1{},make_coord(1,_2{})), shape)); // (_1,(1,_2)) ``` -which produces the following output for the above examples. +#### Index Mapping + +The map from a natural coordinate to an index is performed by taking the inner product of the natural coordinate with the `Layout`'s `Stride`. + +Take the layout `(3,(2,3)):(3,(12,1))`, for example. Then a natural coordinate `(i,(j,k))` will result in the index `i*3 + j*12 + k*1`. The indices this layout computes are shown in the 2-D table below where `i` is used as the row coordinate and `(j,k)` is used as the column coordinate. ``` -> print2D(layout_2sx4s) - 0 2 4 6 - 1 3 5 7 -> print2D(layout_2sx4d) - 0 2 4 6 - 1 3 5 7 -> print2D(layout_2x4) - 0 2 1 3 - 4 6 5 7 + 0 1 2 3 4 5 <== 1-D col coord + (0,0) (1,0) (0,1) (1,1) (0,2) (1,2) <== 2-D col coord (j,k) + +-----+-----+-----+-----+-----+-----+ + 0 | 0 | 12 | 1 | 13 | 2 | 14 | + +-----+-----+-----+-----+-----+-----+ + 1 | 3 | 15 | 4 | 16 | 5 | 17 | + +-----+-----+-----+-----+-----+-----+ + 2 | 6 | 18 | 7 | 19 | 8 | 20 | + +-----+-----+-----+-----+-----+-----+ ``` -The multi-indices within the `layout_2x4` example are handled as expected and interpreted as a rank-2 layout. +The function `cute::crd2idx(c, shape, stride)` is responsible for the index mapping. It will take any coordinate within the shape, compute the equivalent natural coordinate for that shape (if it is not already), and compute the inner product with the strides. +```cpp +auto shape = Shape <_3,Shape< _2,_3>>{}; +auto stride = Stride<_3,Stride<_12,_1>>{}; +print(crd2idx( 16, shape, stride)); // 17 +print(crd2idx(_16{}, shape, stride)); // _17 +print(crd2idx(make_coord( 1, 5), shape, stride)); // 17 +print(crd2idx(make_coord(_1{}, 5), shape, stride)); // 17 +print(crd2idx(make_coord(_1{},_5{}), shape, stride)); // _17 +print(crd2idx(make_coord( 1,make_coord( 1, 2)), shape, stride)); // 17 +print(crd2idx(make_coord(_1{},make_coord(_1{},_2{})), shape, stride)); // _17 +``` -Note that for `layout_2x4`, we're using a 1-D coordinate for a 2-D multi-index in the second mode. In fact, we can generalize this and treat all of the above layouts as 1-D layouts. For instance, the following `print1D` function +## Layout Manipulation -```c++ -template -void print1D(Layout const& layout) -{ - for (int i = 0; i < size(layout); ++i) { - printf("%3d ", layout(i)); - } -} -``` +### Sublayouts -produces the following output for the above examples. +Sublayouts can be retrieved with `layout` +```cpp +Layout a = Layout>>{}; // (4,(3,6)):(1,(4,12)) +Layout a0 = layout<0>(a); // 4:1 +Layout a1 = layout<1>(a); // (3,6):(4,12) +Layout a10 = layout<1,0>(a); // 3:4 +Layout a11 = layout<1,1>(a); // 6:12 +``` +or `select` +```cpp +Layout a = Layout>{}; // (2,3,5,7):(1,2,6,30) +Layout a13 = select<1,3>(a); // (3,7):(2,30) +Layout a01 = select<0,1,3>(a); // (2,3,7):(1,2,30) +Layout a2 = select<2>(a); // (5):(6) +``` +or `take` +```cpp +Layout a = Layout>{}; // (2,3,5,7):(1,2,6,30) +Layout a13 = take<1,3>(a); // (3,5):(2,6) +Layout a14 = take<1,4>(a); // (3,5,7):(2,6,30) +// take<1,1> not allowed. Empty layouts not allowed. +``` +### Concatenation + +A `Layout` can be provided to `make_layout` to wrap and concatenate +```cpp +Layout a = Layout<_3,_1>{}; // 3:1 +Layout b = Layout<_4,_3>{}; // 4:3 +Layout row = make_layout(a, b); // (3,4):(1,3) +Layout col = make_layout(b, a); // (4,3):(3,1) +Layout q = make_layout(row, col); // ((3,4),(4,3)):((1,3),(3,1)) +Layout aa = make_layout(a); // (3):(1) +Layout aaa = make_layout(aa); // ((3)):((1)) +Layout d = make_layout(a, make_layout(a), a); // (3,(3),3):(1,(1),1) ``` -> print1D(layout_8s) - 0 1 2 3 4 5 6 7 -> print1D(layout_8d) - 0 1 2 3 4 5 6 7 -> print1D(layout_2sx4s) - 0 1 2 3 4 5 6 7 -> print1D(layout_2sx4d) - 0 1 2 3 4 5 6 7 -> print1D(layout_2x4) - 0 4 2 6 1 5 3 7 +or can be combined with `append`, `prepend`, or `replace` +```cpp +Layout a = Layout<_3,_1>{}; // 3:1 +Layout b = Layout<_4,_3>{}; // 4:3 +Layout ab = append(a, b); // (3,4):(1,3) +Layout ba = prepend(a, b); // (4,3):(3,1) +Layout c = append(ab, ab); // (3,4,(3,4)):(1,3,(1,3)) +Layout d = replace<2>(c, b); // (3,4,4):(1,3,3) ``` -This shows explicitly that all of the layouts are simply folded views of an 8-element array. +### Grouping + +Layout modes can be grouped with `group` and flattened with `flatten` +```cpp +Layout a = Layout>{}; // (_2,_3,_5,_7) +Layout b = group<0,2>(a); // ((_2,_3),_5,_7) +Layout c = group<1,3>(b); // ((_2,_3),(_5,_7)) +Layout f = flatten(c); // (_2,_3,_5,_7) +``` + +### Slicing + +`Layout`s can be sliced, but slicing is more appropriate to perform on `Tensor`s. See the [`Tensor` section](./03_tensor.md) for slicing details. ## Summary * The `Shape` of a `Layout` defines its coordinate space(s). * Every `Layout` has a 1-D coordinate space. - This can be used to iterate in a "generalized-column-major" order. + This can be used to iterate over the coordinate spaces in a colexicographical order. * Every `Layout` has a R-D coordinate space, where R is the rank of the layout. - These spaces are ordered _colexicographically_ - (reading right to left, instead of "lexicographically," - which reads left to right). - The enumeration of that order - corresponds to the 1-D coordinates above. + The colexicographical enumeration of the R-D coordinates + correspond to the 1-D coordinates above. - * Every `Layout` has an h-D coordinate space where h is "hierarchical." These are ordered colexicographically and the enumeration of that order corresponds to the 1-D coordinates above. An h-D coordinate is congruent to the `Shape` so that each element of the coordinate has a corresponding element of the `Shape`. + * Every `Layout` has an h-D (natural) coordinate space where h is "hierarchical." These are ordered colexicographically and the enumeration of that order corresponds to the 1-D coordinates above. A natural coordinate is *congruent* to the `Shape` so that each element of the coordinate has a corresponding element of the `Shape`. * The `Stride` of a `Layout` maps coordinates to indices. - * In general, this could be any function from 1-D coordinates (integers) to indices (integers). + * The inner product of the elements of the natural coordinate with the elements of the `Stride` produces the resulting index. + +For each `Layout` there exists an integral `Shape` that is that compatible with that `Layout`. Namely, that integral shape is `size(layout)`. We can then observe that + +> Layouts are functions from integers to integers. - * In `CuTe` we use an inner product of the h-D coordinates with the `Stride` elements. +If you're familiar with the C++23 feature `mdspan`, +this is an important difference between +`mdspan` layout mappings and CuTe `Layout`s. In CuTe, `Layout` is a first class citizen, is natively hierarchical to naturally represent functions beyond row-major and column-major, and can similarly be indexed with a hierarchy of coordinates. +(`mdspan` layout mappings can represent hierarchical functions as well, +but this requires defining a custom layout.) +Input coordinates for an `mdspan` must have the same shape as the `mdspan`; +a multidimensional `mdspan` does not accept 1-D coordinates. diff --git a/media/docs/cute/02_layout_algebra.md b/media/docs/cute/02_layout_algebra.md new file mode 100644 index 00000000..ec42318b --- /dev/null +++ b/media/docs/cute/02_layout_algebra.md @@ -0,0 +1,572 @@ +# CuTe Layout Algebra + +CuTe provides an "algebra of `Layout`s" to support combining layouts in different ways. This algebra includes operations such as + +* `Layout` functional composition, +* a notion of `Layout` "product" to reproduce one layout according to another, and +* a notion of `Layout` "divide" to split one layout according to another. + +Common utilities for building complicated layouts from simpler ones depend on the `Layout` product. Common utilities for partitioning layouts (of data, for example) across other layouts (of threads, for example) depend on the `Layout` divide. All of these utilities rely on the functional composition of `Layout`s. + +In this section, we'll build up the tools of the `Layout` algebra and explain some of these core operations in detail. + +## Coalesce + +In the previous section, we summarized `Layout`s with +> Layouts are functions from integers to integers. + +The `coalesce` operation is a "simplify" on functions from integers to integers. If we only care about input integers, then we can manipulate the shape and number of modes of the `Layout` without changing it as a function. The only thing `coalesce` can't change is the `Layout`'s `size`. + +More specifically, you can find the checked post-conditions in [the `coalesce` unit test](../../../test/unit/cute/core/coalesce.cpp), which we'll reproduce here: +```cpp +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a result(i) == @a layout(i) +Layout coalesce(Layout const& layout) +``` + +For example, + +```cpp +auto layout = Layout>, + Stride<_1,Stride<_6,_2>>>{}; +auto result = coalesce(layout); // _12:_1 +``` + +where we can see the result has fewer modes and is "simpler." Indeed, this could save us a few operations in the coordinate mapping and index mapping (if those are performed dynamically). + +So, how do we get there? + +* We've already seen that column-major `Layout`s like `(_2,_4):(_1,_2)` act identically to `_8:_1` for 1-D coordinates. +* Modes with size static-1 will always produce a natural coordinate of static-0. They can be ignored no matter the stride. + +Generalizing, consider a layout with just two integral modes, s0:d0 and s1:d1. Denote the result of coalescing this layout as s0:d0 ++ s1:d1. Then, there are four cases: + +1. `s0:d0 ++ _1:d1 => s0:d0`. Ignore modes with size static-1. +2. `_1:d0 ++ s1:d1 => s1:d1`. Ignore modes with size static-1. +3. `s0:d0 ++ s1:s0*d0 => s0*s1:d0`. If the second mode's stride is the product of the first mode's size and stride, then they can be combined. +4. `s0:d0 ++ s1:d1 => (s0,s1):(d0,d1)`. Else, nothing can be done and they must be treated separately. + +That's it! We can flatten any layout and apply the above binary operation to each pair of adjacent modes in order to "coalesce" the modes of the layout. + +### By-mode Coalesce + +Obviously, sometimes we do care about the shape of our `Layout`, but would still like to coalesce. For example, I have a 2-D `Layout` and I would like the result to remain 2-D. + +For this reason, there's an overload of `coalesce` that takes an additional parameter +```cpp +// Apply coalesce at the terminals of trg_profile +Layout coalesce(Layout const& layout, IntTuple const& trg_profile) +``` + +which can be used as follows + +```cpp +auto a = Layout>, + Stride<_1,Stride<_6,_2>>>{}; +auto result = coalesce(a, Step<_1,_1>{}); // (_2,_6):(_1,_2) +// Identical to +auto same_r = make_layout(coalesce(layout<0>(a)), + coalesce(layout<1>(a))); +``` + +This function is recursing into `Step<_1,_1>{}` and applying `coalesce` to the corresponding sublayout whenever it sees an integer (the values don't matter, they're just flags) rather than a tuple. + +> This theme of defining an operation that treats a `Layout` as a "1-D" function from integers to integers and then generalizing to use it for an arbitrarily shaped layout will be a common one! + +## Composition + +Functional composition of `Layout`s is the core of CuTe and is used in just about every higher-level operation. + +Starting again from the observation that `Layout`s are just functions from integers to integers, we can define functional composition that results in another `Layout`. First, an example. + +```text +Functional composition, R := A o B +R(c) := (A o B)(c) := A(B(c)) + +Example +A = (6,2):(8,2) +B = (4,3):(3,1) + +R( 0) = A(B( 0)) = A(B(0,0)) = A( 0) = A(0,0) = 0 +R( 1) = A(B( 1)) = A(B(1,0)) = A( 3) = A(3,0) = 24 +R( 2) = A(B( 2)) = A(B(2,0)) = A( 6) = A(0,1) = 2 +R( 3) = A(B( 3)) = A(B(3,0)) = A( 9) = A(3,1) = 26 +R( 4) = A(B( 4)) = A(B(0,1)) = A( 1) = A(1,0) = 8 +R( 5) = A(B( 5)) = A(B(1,1)) = A( 4) = A(4,0) = 32 +R( 6) = A(B( 6)) = A(B(2,1)) = A( 7) = A(1,1) = 10 +R( 7) = A(B( 7)) = A(B(3,1)) = A(10) = A(4,1) = 34 +R( 8) = A(B( 8)) = A(B(0,2)) = A( 2) = A(2,0) = 16 +R( 9) = A(B( 9)) = A(B(1,2)) = A( 5) = A(5,0) = 40 +R(10) = A(B(10)) = A(B(2,2)) = A( 8) = A(2,1) = 18 +R(11) = A(B(11)) = A(B(3,2)) = A(11) = A(5,1) = 42 +``` + +The absolutely amazing observation is that the function `R(c) = k` defined above can be written down as another `Layout` + +``` +R = ((2,2),3):((24,2),8) +``` + +AND + +``` +compatible(B, R) +``` + +That is, every coordinate of `B` can also be used as a coordinate of `R`. This is an expected property of functional composition because `B` defines the *domain* of `R`. + +You can find many examples and checked post-conditions in [the `composition` unit test](../../../test/unit/cute/core/composition.cpp). The post-conditions are precisely as we just stated. +```cpp +// @post compatible(@a layout_b, @a result) +// @post for all i, 0 <= i < size(@a layout_b), @a result(i) == @a layout_a(@a layout_b(i))) +Layout composition(LayoutA const& layout_a, LayoutB const& layout_b) +``` + +### Computing Composition + +First, a few observations: + +* `B = (B_0, B_1, ...)`. A layout can be expressed as the concatenation of its sublayouts. + +* `A o B = A o (B_0, B_1, ...) = (A o B_0, A o B_1, ...)`. When `B` is injective, composition is left-distributive with concatenation. + +With the above, we can assume without loss of generality that `B = s:d` is a layout with integral shape and stride. We can also assume that `A` is a flattened, coalesced layout. + +When `A` is integral, `A = a:b`, the result is rather trivial: `R = A o B = a:b o s:d = s:(b*d)`. But when `A` is multimodal, we need to be more careful. + +Put into words, `A o B = A o s:d`, for integral `s` and `d` means that we want (1) every `d`th element of `A`, and then (2) keep the first `s` of those strided elements. + +1. Every `d`th element of `A` can be computed by "dividing out" the first `d` elements from the shape of `A`. For an array of integers representing the shape, this is computed as +```cpp +void shape_div(int* shapeA, int N, int& strideB) { + for (int i = 0; i < N; ++i) { + assert(shapeA[i] % strideB == 0 or + strideB % shapeA[i] == 0); + int new_shape = ceil_div(shapeA[i], strideB); + int new_stride = ceil_div(strideB, shapeA[i]); + shapeA[i] = new_shape; + strideB = new_stride; + } +} +``` +which progressively "removes" the first `strideB` elements from `shapeA` starting from the left. For example, +* `(6,2) / 2 => (3,2)` +* `(6,2) / 3 => (2,2)` +* `(6,2) / 6 => (1,2)` +* `(6,2) / 12 => (1,1)` +* `(3,6,2,8) / 6 => (1,3,2,8)` +* `(3,6,2,8) / 9 => (1,2,2,8)` +* `(42,16,3) / 2 => (21,16,3)` +* `(42,16,3) / 6 => ( 7,16,3)` + +As you may have noticed, we can only divide shapes by certain values and get a sensible result. This is called the **divisibility condition** and is enforced by the `assert` in the above code and statically checked in CuTe when possible. + +2. The first `s` elements of the strided `A` layout can be computed by "modding out" the first `s` elements from the shape of `A`. For an array of integers representing the shape, this is computed as +```cpp +void shape_mod(int* shapeA, int N, int& shapeB) { + for (int i = 0; i < N; ++i) { + assert(shapeA[i] % shapeB == 0 or + shapeB % shapeA[i] == 0); + int new_shapeA = min(shapeA[i], shapeB); + int new_shapeB = ceil_div(shapeB, shapeA[i]); + shapeA[i] = new_shapeA; + shapeB = new_shapeB; + } +} +``` +which progressibly "keeps" the first `shapeB` elements from `shapeA` starting from the left. For example, +* `(6,2) % 2 => (2,1)` +* `(6,2) % 3 => (3,1)` +* `(6,2) % 6 => (6,1)` +* `(6,2) % 12 => (6,2)` +* `(3,6,2,8) % 6 => (3,2,1,1)` +* `(3,6,2,8) % 9 => (3,3,1,1)` +* `(1,2,2,8) % 2 => (1,2,1,1)` +* `(1,2,2,8) % 16 => (1,2,2,4)` + +Again, this operation must satisfy the divisibility condition to yield a sensible result. This is enforced by the `assert` in the above code and statically checked in CuTe when possible. + +Clearly, CuTe does not use arrays to store shapes or strides and the above code is for explication only. CuTe works with shapes and strides as `IntTuple`s and the implementation is expressed as algorithmic `fold`s which carefully account for static and dynamic integers. + +#### Example 1 -- Reshape a layout into a matrix + +`20:2 o (5,4):(4,1)`. + +This describes interpreting the layout `20:2` +as a 5x4 matrix in a row-major order. + +1. ` = 20:2 o (5:4,4:1)`. Concatenation of sublayouts. + +2. ` = (20:2 o 5:4, 20:2 o 4:1)`. Left distributivity. + + * `20:2 o 5:4 => 5:8`. Trivial case. + * `20:2 o 4:1 => 4:2`. Trivial case. + +3. ` = (5:8, 4:2)`. + +4. ` = (5,4):(8,2)`. Concatenation of sublayouts. + +#### Example 2 -- Reshape a layout into a matrix + +`(10,2):(16,4) o (5,4):(1,5)` + +This describes interpreting the layout `(10,2):(16,4)` +as a 5x4 matrix in a col-major order. + +1. ` = (10,2):(16,4) o (5:1,4:5)`. Concatenation of sublayouts. + +2. ` = ((10,2):(16,4) o 5:1, (10,2):(16,4) o 4:5)`. Left distributivity. + + * `(10,2):(16,4) o 5:1 => (5,1):(16,4)`. Mod out the shape `5`. + * `(10,2):(16,4) o 4:5 => (2,2):(80,4)`. Div out the stride `5`. + +3. ` = ((5,1):(16,4), (2,2):(80,4))`. Collect results. + +4. ` = (5:16, (2,2):(80,4))`. By-mode coalesce. + +5. ` = (5,(2,2))):(16,(80,4))`. Concatenation of sublayouts. + +We get exactly this result with CuTe +if we use compile-time shapes and strides. +The following C++ code prints `(_5,(_2,_2)):(_16,(_80,_4))`. + +```cpp +Layout a = make_layout(make_shape (Int<10>{}, Int<2>{}), + make_stride(Int<16>{}, Int<4>{})); +Layout b = make_layout(make_shape (Int< 5>{}, Int<4>{}), + make_stride(Int< 1>{}, Int<5>{})); +Layout c = composition(a, b); +print(c); +``` + +If we use dynamic integers, the following C++ code prints `((5,1),(2,2)):((16,4),(80,4))`. + +```cpp +Layout a = make_layout(make_shape (10, 2), + make_stride(16, 4)); +Layout b = make_layout(make_shape ( 5, 4), + make_stride( 1, 5)); +Layout c = composition(a, b); +print(c); +``` + +The results may _look_ different but are the mathematically the same. The 1s in the shape don't affect the layout as a mathematical function from 1-D coordinates to integers or as a function from 2-D coordinates to integers. In the dynamic case, CuTe can not coalesce the dynamic size-1 modes to "simplify" the layout due to the static rank and type of the tuples containing them. + +### By-mode Composition + +Similar to by-mode `coalesce` and building up to a generic tiling operation, sometimes we do care about the shape of the `A` layout and would still like to apply `composition` to individual modes. For example, I have a 2-D `Layout` and would like some sublayout of the elements down the columns and another sublayout of elements across the rows. + +For this reason, `composition` also works when its second parameter -- the `B` -- is a `Tiler`. In general, a tiler is a layout or a tuple-of-layouts (note the generalization on `IntTuple`), which can be used as follows +```cpp +// (12,(4,8)):(59,(13,1)) +auto a = make_layout(make_shape (12,make_shape ( 4,8)), + make_stride(59,make_stride(13,1))); +// <3:4, 8:2> +auto tiler = make_tile(Layout<_3,_4>{}, // Apply 3:4 to mode-0 + Layout<_8,_2>{}); // Apply 8:2 to mode-1 + +// (_3,(2,4)):(236,(26,1)) +auto result = composition(a, tiler); +// Identical to +auto same_r = make_layout(composition(layout<0>(a), get<0>(tiler)), + composition(layout<1>(a), get<1>(tiler))); +``` +We often use the `` notation to distinguish `Tiler`s from the concatenation-of-sublayouts notation `(LayoutA, LayoutB, ...)` that we used previously. + +The `result` in the above code can be depicted as the 3x8 sublayout of the original layout highlighted in the figure below. +

+ composition1.png +

+ +For convenience, CuTe also interprets `Shape`s as a tiler as well. A `Shape` is interpreted as tuple-of-layouts-with-stride-1: +```cpp +// (12,(4,8)):(59,(13,1)) +auto a = make_layout(make_shape (12,make_shape ( 4,8)), + make_stride(59,make_stride(13,1))); +// (8, 3) +auto tiler = make_shape(Int<3>{}, Int<8>{}); +// Equivalent to <3:1, 8:1> +// auto tiler = make_tile(Layout<_3,_1>{}, // Apply 3:1 to mode-0 +// Layout<_8,_1>{}); // Apply 8:1 to mode-1 + +// (_3,(4,2)):(59,(13,1)) +auto result = composition(a, tiler); +``` +where `result` can be depicted as the 3x8 sublayout of the original layout highlighted in the figure below. +

+ composition2.png +

+ +## Composition Tilers + +In summary, a `Tiler` is one of the following objects. +1. A `Layout`. +2. A tuple of `Tiler`s. +3. A `Shape`, which will be interpreted as a tiler of `Layout`s with stride-1. + +Any of the above can be used as the second argument in `composition`. With (1), we think of the `composition` as between two functions from integers to integers, no matter the ranks of the layouts. With (2) and (3), the `composition` is performed on each pair of corresponding modes of `A` and `B`, until case (1) is found. + +This allows composition to be applied by-mode to retrieve arbitrary sublayouts of specified modes of a tensor ("Give me the 3x5x8 subblock of this MxNxL tensor") but also allows entire tiles of data to be reshaped and reordered as if they were 1-D vectors ("Reorder this 8x16 block of data into a 32x4 block using this weird order of elements"). We will see the by-mode cases appear often when we are tiling for threadblocks in examples that follow. We will see 1-D reshaping and reordering when we want to apply arbitrary partitioning patterns for threads and values in MMAs in examples that follow. + +## Complement + +Before getting to "product" and "divide," we need one more operation. We can think of `composition` as a layout `B` that is "selecting" certain coordinates from another layout `A`. But what about the coordinates that aren't "selected"? To implement generic tiling, we want to be able to select arbitrary elements -- the tile -- and to describe the layout of those tiles -- the leftovers, or the "rest." + +The `complement` of a layout attempts to find another layout that represents the "rest" -- the elements that aren't touched by the layout. + +You can many examples and checked post-conditions in [the `complement` unit test](../../../test/unit/cute/core/complement.cpp). The post-conditions include +```cpp +// @post cosize(make_layout(@a layout_a, @a result))) >= @a cosize_hi +// @post cosize(@a result) >= round_up(@a cosize_hi, cosize(@a layout_a)) +// @post for all i, 1 <= i < size(@a result), +// @a result(i-1) < @a result(i) +// @post for all i, 1 <= i < size(@a result), +// for all j, 0 <= j < size(@a layout_a), +// @a result(i) != @a layout_a(j) +Layout complement(LayoutA const& layout_a, Integral const& cosize_hi) +``` +That is, the complement `R` of a layout `A` with respect to an integer `M` satisfies the following properties. +1. The size (and cosize) of `R` is bounded by `M`. +2. `R` is *ordered*. That is, the strides of `R` are positive and increasing. This means that `R` is unique. +3. `A` and `R` have *disjoint* codomains. `R` attempts to "complete" the codomain of `A`. + +### Complement Examples + +`complement` is most effective on static shapes and strides, so consider all integers below to be static. Similar examples for dynamic shapes and strides can be found in the unit test. + +* `complement(4:1, 24)` is `6:4`. Note that `(4,6):(1,4)` has cosize `24`. The layout `4:1` is effectively repeated 6 times with `6:4`. + +* `complement(6:4, 24)` is `4:1`. Note that `(6,4):(4,1)` has cosize `24`. The "hole" in `6:4` is filled with `4:1`. + +* `complement((4,6):(1,4), 24)` is `1:0`. Nothing needs to be appended. + +* `complement(4:2, 24)` is `(2,4):(1,8)`. Note that `(4,(2,4)):(2,(1,8))` has cosize `24`. The "hole" in `4:2` is filled with `2:1` first, then everything is repeated 4 times with `4:8`. + +* `complement((2,4):(1,6), 24)` is `3:2`. Note that `((2,4),3):((1,6),2)` has cosize `24` and produces unique indices. + +* `complement((2,2):(1,6), 24)` is `(3,2):(2,12)`. Note that `((2,4),(2,2)):((1,6),(2,12))` has cosize `24` and produces unique indices. + +

+ complement1.png +

+As a visualization, the above figure depicts the codomain of the last example. The image of the original layout `(2,2):(1,6)` is colored in gray. The complement effectively "repeats" the original layout (displayed in the other colors) such that the codomain size of the result is `24`. The complement `(3,2):(2,12)` can be viewed as the "layout of the repetition." + +## Division (Tiling) + +Finally, we can define the division of a `Layout` by another `Layout`. Functions that divide a layout into components are useful as a basis for tiling and partitioning layouts. + +In this section, we'll define `logical_divide(Layout, Layout)`, which again considers all `Layout`s as 1-D functions from integers to integers, and then use that definition to create multidimensional `Layout` divides. + +Informally, `logical_divide(A, B)` splits a layout `A` into two modes -- in the first mode are all elements pointed to by `B` and in the second mode are all elements not pointed to by `B`. + +Formally, this can be written as + +$A \oslash B := A \circ (B,B^*)$ + +and implemented as +```cpp +template +auto logical_divide(Layout const& layout, + Layout const& tiler) +{ + return composition(layout, make_layout(tiler, complement(tiler, size(layout)))); +} +``` +Note that this is defined only in terms of concatenation, composition, and complement. + +So what is that? + +> in the first mode are all elements pointed to by `B` + +This is clearly composition, `A o B`. + +> in the second mode are all elements not pointed to by `B` + +The elements NOT pointed to by `B` sounds like a complement, `B*`, up to the size of `A`. As we've seen above in the `complement` section, this can be described as the "layout of the repetition of `B`." If `B` is the "tiler", then `B*` is the layout of the tiles. + +### Logical Divide 1-D Example + +Consider tiling the 1-D layout `A = (2,4,3):(4,1,8)` with the tiler `B = 4:2`. Informally, this means that we have a 1-D vector of 24 elements in some storage order defined by `A` and we want to extract tiles of 4 elements strided by 2. + +This is computed in the three steps described in the implementation above. +* Complement of `B = 4:2` under `size(A) = 24` is `B* = (2,3):(1,8)`. +* Concantenation of `(B,B*) = (4,(2,3)):(2,(1,8))`. +* Composition of `A = (2,4,3):(4,1,8)` with `(B,B*)` is then `((2,2),(2,3)):((4,1),(2,8))`. + +

+ divide1.png +

+ +The above figure depicts `A` as a 1-D layout with the elements pointed to by `B` highlighted in gray. The layout `B` describes our "tile" of data, and there are six of those tiles in `A` shown by each of the colors. After the divide, the first mode of the result is the tile of data and the second mode of the result iterates over each tile. + +### Logical Divide 2-D Example + +Using the `Tiler` concept defined above, this immediately generalizes to multidimensional tiling. The below example simply applies `layout_divide` by-mode to the cols and rows of a 2-D layout using a `Tiler`. + +Similar to the 2-D composition example above, consider a 2-D layout `A = (9,(4,8)):(59,(13,1))` and want to apply `3:3` down the columns (mode-0) and `(2,4):(1,8)` across the rows (mode-1). This means the tiler can be written as `B = <3:3, (2,4):(1,8)>`. + +

+ divide2.png +

+ +The above figure depicts `A` as a 2-D layout with the elements pointed to by `B` highlighted in gray. The layout `B` describes our "tile" of data, and there are twelve of those tiles in `A` shown by each of the colors. After the divide, the first mode of each mode of the result is the tile of data and the second mode of each mode iterates over each tile. In that sense, this operation can be viewed as a kind of `gather` operation or as simply a permutation on the rows and cols. + +Note that the first mode of each mode of the result is the sublayout `(3,(2,4)):(236,(13,52))` and is precisely the result we would have received if we had applied `composition` instead of `logical_divide`. + +### Zipped, Tiled, Flat Divides + +It's easy to see the tiles when they are highlighted in the images above, but working with them can still be awkward. How would you slice out the `3`rd tile or the `7`th tile or the `(1,2)`th tile so you could continue working on it? + +Enter the convenience flavors of `logical_divide`. Suppose we have a `Layout` and a `Tiler` of some shape, then each operation will apply `logical_divide`, but potentially rearrange the modes into more convenient forms. +```text +Layout Shape : (M, N, L, ...) +Tiler Shape : + +logical_divide : ((TileM,RestM), (TileN,RestN), L, ...) +zipped_divide : ((TileM,TileN,...), (RestM,RestN,L,...)) +tiled_divide : ((TileM,TileN,...), RestM, RestN, L, ...) +flat_divide : (TileM, TileN, ..., RestM, RestN, L, ...) +``` + +For example, the `zipped_divide` function applies `logical_divide`, and then gathers the "subtiles" into a single mode and the "rest" into a single mode. +```cpp +// A: shape is (9,32) +auto layout_a = make_layout(make_shape (Int< 9>{}, make_shape (Int< 4>{}, Int<8>{})), + make_stride(Int<59>{}, make_stride(Int<13>{}, Int<1>{}))); +// B: shape is (3,8) +auto tiler = make_tile(Layout<_3,_3>{}, // Apply 3:3 to mode-0 + Layout, // Apply (2,4):(1,8) to mode-1 + Stride<_1,_8>>{}); + +// ((TileM,RestM), (TileN,RestN)) with shape ((3,3), (8,4)) +auto ld = logical_divide(layout_a, tiler); +// ((TileM,TileN), (RestM,RestN)) with shape ((3,8), (3,4)) +auto zd = zipped_divide(layout_a, tiler); +``` +Then, the offset to the `3`rd tile is `zd(0,3)`. The offset to the `7`th tile is `zd(0,7)`. The offset to the `(1,2)`th tile is `zd(0,make_coord(1,2))`. The tile itself always has layout `layout<0>(zd)`. Indeed, it is always the case that + +`layout<0>(zipped_divide(a, b)) == composition(a, b)`. + +We note that `logical_divide` preserves the *semantics* of the modes while permuting the elements within those modes -- the `M`-mode of layout `A` is still the `M`-mode of the result and the `N`-mode of layout `A` is still the `N`-mode of the result. + +This is not the case with `zipped_divide`. The mode-0 in the `zipped_divide` result is the `Tile` itself (of whatever rank the `Tiler` was) and mode-1 is the layout of those tiles. It doesn't always make sense to plot these as 2-D layouts, because the `M`-mode is now more aptly the "tile-mode" and the `N`-mode is more aptly the "rest-mode". Regardless, we still can plot the resulting layout as 2-D as shown below. + +

+ divide3.png +

+ +We've kept each tile as its color in the previous images for clarity. Clearly, iterating across tiles is now equivalent to iterating across a row of this layout and iterating over elements within a tile is equivalent to iterating down a column of this layout. As we'll see in the `Tensor` section, this can be used to great effect in partitioning within or across tiles of data. + +## Product (Tiling) + +Finally, we can define the product of a Layout by another Layout. In this section, we'll define `logical_product(Layout, Layout)`, which again considers all `Layout`s as 1-D functions from integers to integers, and then use that definition to create multidimensional `Layout` products. + +Informally, `logical_product(A, B)` results in a two mode layout where the first mode is the layout `A` and the second mode is the layout `B` but with each element replaced by a "unique replication" of layout `A`. + +Formally, this can be written as + +$A \otimes B := (A, A^* \circ B)$ + +and implemented in CuTe as +```cpp +template +auto logical_product(Layout const& layout, + Layout const& tiler) +{ + return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler)); +} +``` +Note that this is defined only in terms of concatenation, composition, and complement. + +So what is that? + +> where the first mode is the layout `A` + +This is clearly just a copy of `A`. + +> the second mode is the layout `B` but with each element replaced by a "unique replication" of layout `A`. + +The "unique replication" of layout `A` sounds like complement, `A*`, up to the cosize of `B`. As we've seen in the `complement` section, this can be described as the "layout of the repetition of `A`". If `A` is the "tile", then `A*` is the layout of repetitions that are available for `B`. + +### Logical Product 1-D Example + +Consider reproducing the 1-D layout `A = (2,2):(4,1)` according to `B = 6:1`. Informally, this means that we have a 1-D layout of 4 elements defined by `A` and we want to reproduce it 6 times. + +This is computed in the three steps described in the implementation above. +* Complement of `A = (2,2):(4,1)` under `6*4 = 24` is `A* = (2,3):(2,8)`. +* Composition of `A* = (2,3):(2,8)` with `B = 6:1` is then `(2,3):(2,8)`. +* Concatenation of `(A,A* o B) = ((2,2),(2,3)):((4,1),(2,8))`. + +

+ product1.png +

+ +The above figure depicts `A` and `B` as a 1-D layouts. The layout `B` describes the number and order of repetitions of `A` and they are colored for clarity. After the product, the first mode of the result is the tile of data and the second mode of the result iterates over each tile. + +Note that the result is identical to the result of the 1-D Logical Divide example. + +Of course, we can change the number and order of the tiles in the product by changing `B`. + +

+ product2.png +

+ +For example, in the above image with `B = (4,2):(2,1)`, there are 8 repeated tiles instead of 6 and the tiles are in a different order. + +### Logical Product 2-D Example + +We can use the by-mode `tiler` strategies previously developed to write multidimensional products as well. + +

+ product2d.png +

+ +The above image demonstates the use of a `tiler` to apply `logical_product` by-mode. Despite this **not being the recommended approach**, the result is a rank-2 layout consisting of 2x5 row-major block that is tiled across a 3x4 col-major arrangement. + +The reason **this is not the recommended approach** is that the `tiler B` in the above expression is highly unintuitive. In fact, it requires perfect knowledge of the shape and strides of `A` in order to construct. We would like to express "Tile Layout `A` according to Layout `B`" in a way that makes `A` and `B` independent and is much more intuitive. + +#### Blocked and Raked Products + +The `blocked_product(LayoutA, LayoutB)` and `raked_product(LayoutA, LayoutB)` are interesting, more intuitive, rank-sensitive transformations on top of 1-D `logical_product` that let us express the intuitive Layout products that we most often want to express. + +A key observation in the implementation of these functions are the compatibility post-conditions of `logical_product`: +``` +// @post rank(result) == 2 +// @post compatible(layout_a, layout<0>(result)) +// @post compatible(layout_b, layout<1>(result)) +``` + +Because `A` is always compatible with mode-0 of the result and `B` is always compatible with mode-1 of the result, if we made `A` and `B` the same rank then we could "reassociate" like-modes after the product. That is, the "col" mode in `A` could be combined with the "col" mode in `B` and the "row" mode in `A` could be combined with the "row" mode in `B`, etc. + +This is exactly what `blocked_product` and `raked_product` do and it is why they are called rank-sensitive. Unlike other CuTe functions that take `Layout` arguments, these care about the top-level rank of the arguments so that each mode can be reassociated after the `logical_product`. + +

+ productblocked2d.png +

+ +The above image shows the same result as the `tiler` approach, but with much more intuitive arguments. A 2x5 row-major layout is arranged as a tile in a 3x4 col-major arrangement. Also note that `blocked_product` went ahead and `coalesced` mode-0 for us. + +Similarly, `raked_product` combines the modes slightly differently. Instead of the resulting "col" mode being constructed from the `A` "col" mode then the `B` "col" mode, the resulting "col" mode is constructed from the `B` "col" mode then the `A` "col" mode. + +

+ productraked2d.png +

+ +This results in the "tile" `A` now being interleaved or "raked" with the "layout-of-tiles" `B` instead of appearing as blocks. Other references call this a "cyclic distribution." + +### Zipped and Tiled Products + +Similar to `zipped_divide` and `tiled_divide`, the `zipped_product` and `tiled_product` simply rearrange the modes that result from a by-mode `logical_product`. + +```text +Layout Shape : (M, N, L, ...) +Tiler Shape : + +logical_product : ((M,TileM), (N,TileN), L, ...) +zipped_product : ((M,N), (TileM,TileN,L,...)) +tiled_product : ((M,N), TileM, TileN, L, ...) +flat_product : (M, N, TileM, TileN, L, ...) +``` diff --git a/media/docs/cute/02_layout_operations.md b/media/docs/cute/02_layout_operations.md deleted file mode 100644 index 7860cb7d..00000000 --- a/media/docs/cute/02_layout_operations.md +++ /dev/null @@ -1,833 +0,0 @@ -# CuTe Layout Operations - -CuTe provides an "algebra of `Layout`s." -`Layout`s can be combined and manipulated -to construct more complicated `Layout`s. -This includes tiling and partitioning `Layout`s across other `Layout`s. -In this section, we explain some of these core operations in detail. - -## How do I print CuTe objects on host or device? - -CuTe comes with different ways to print CuTe objects. -You can print human-readable text, -or you can print LaTeX commands for generating -a beautifully formatted and colored table -describing the CuTe object. -Both of these can be helpful for reasoning about or debugging -layouts, copy atoms, or matrix multiply atoms -(don't worry, we'll explain all of these things in this tutorial). - -CuTe's print functions work on either host or device. -Note that on device, printing is expensive. -Even just leaving print code in place on device, -even if it is never called -(e.g., printing in an `if` branch that is not taken at run time), -may generate slower code. -Thus, be sure to remove code that prints on device after debugging. - -The following code examples assume that you have a -`using namespace cute;` statement in scope. - -### Printing human-readable text - -The `cute::print` function has overloads for almost all CuTe types, including Pointers, Layout, Shape, Stride, and Tensors. When in doubt, try calling `print` on it. You might also only want to print on thread 0 of each thread block, or block 0 of the grid. The `thread0()` function returns true only for global thread 0 of the kernel. A typical idiom for printing CuTe objects to print only on thread 0 of block 0. - -```c++ -if (thread0()) { - print(some_cute_object); -} -``` - -Some algorithms do different things on different threads or blocks, -so you might sometimes need to print on threads or blocks other than zero. -The header file -[`cute/util/debug.hpp`](../../../include/cute/util/debug.hpp), -among other utilities, -includes the function `bool thread(int tid, int bid)` -that returns `true` if running on thread `tid` and block `bid`. - -Some CuTe types have special printing functions that use a different output format. -For example, `print_layout` can display a rank-2 layout in a table -(using plain text formatting). -It has an overload taking a rank-2 matrix layout and a thread layout, -that displays a table with the mapping between threads and values. - -### Printing LaTeX output - -The `cute::print_latex` function works like `cute::print`, -but prints LaTeX commands that you can use -to generate a nicely formatted and colored table. - -## Fundamental types - -### Layout and its components - -This directory includes -[an overview of CuTe's fundamental types for describing layouts](./01_layout.md). - -#### Tuple - -CuTe starts with a Tuple, which is a finite ordered list of zero or more elements. -In C++, we identify a Tuple with the -[`cute::tuple` class](../../../include/cute/container/tuple.hpp). -`cute::tuple` behaves like `std::tuple`, but it works on device or host, -and it imposes restrictions on its template arguments for performance and simplicity. - - -#### IntTuple - -CuTe then defines an IntTuple as either an integer, or a Tuple of IntTuple. -This recursive definition lets us build arbitrarily nested layouts. -In C++, we identify an IntTuple with [`IntTuple`](../../../include/cute/int_tuple.hpp), -which is just an alias of `cute::tuple`. -Any of the following are thus valid template arguments of IntTuple. - -1. "Run-time integers" (or "static integers") - are just ordinary integral types like `int` or `size_t`. - -2. "Compile-time integers" include `std::integral_constant` - or subclasses of it that CuTe defines, - such as `Int` (see below). - These types all have in common - that the value is encoded in the type itself - (as a public `static constexpr value` member). - CuTe defines aliases `_1`, `_2`, `_3` etc. - to the types `Int<1>`, `Int<2>`, `Int<3>` etc. - -3. `IntTuple` with any valid template arguments. - -CuTe reuses IntTuple for many different things, -including Shape, Stride, Step, and Coord -(see [`include/cute/layout.hpp`](../../../include/cute/layout.hpp)). -In C++, Shape, Stride, Step, and Coord are all aliases for IntTuple. - -### Layout - -A Layout is a tuple of (Shape, Stride). -Semantically, it implements a mapping from -a "logical" Shape-shaped (multidimensional) index, -to a "physical" 1-D index into an array. -Here is an example of a 2 x 3 array with static strides (3, 1). - -```c++ -Layout layout = make_layout(make_shape (_2{}, _3{}), - make_stride(_3{}, _1{})); -print_layout(layout); -for (int i = 0; i < size(layout); ++i) { - print(layout(i)); - print(", "); -} -print("\n"); -print(layout(1, 1)); -print("\n"); -``` - -This code produces the following text output. - -```text -(_2,_3):(_3,_1) - 0 1 2 - +---+---+---+ - 0 | 0 | 1 | 2 | - +---+---+---+ - 1 | 3 | 4 | 5 | - +---+---+---+ -0, 3, 1, 4, 2, 5, -4 -``` - -`print(layout(1, 1))` prints the mapping of -the logical 2-D coordinate (1,1) to the 1-D index, which is 4. -You can see that from the table, -which shows the left logical index as the "row," -and the right logical index as the "column." - -### Underscore (`_`) - -An Underscore is a special type used for array slices. The underscore punctuation `_` is a constant instance of Underscore. It acts like `:` (the colon punctuation) in Python or Fortran array slices. See [`include/cute/underscore.hpp`](../../../include/cute/underscore.hpp). - -### Tile - -"A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores." -See [`include/cute/tile.hpp`](../../../include/cute/tile.hpp). - -The algebraic layout operations discussed below are defined on `Layout`s, but `Tile` allows these operations to recurse and to be applied to sublayouts or particular modes of a given Layout. These are referred to as by-mode operations. - -See the section on "Logical Divide" to see an example of using `Tile` to extract portions of a row-mode and portions of a column-mode independently. - -## Layout definitions and operations - -### Layouts are functions from integers (logical 1-D coordinate) to integers (1-D index) - -The `for` loop in the above print example shows how CuTe identifies 1-D coordinates with a column-major layout of logical 2-D coordinates. Iterating from `i = 0` to `size(layout)` (which is 6), and indexing into our layout with the single integer coordinate `i`, traverses the layout in column-major fashion, even though this is a row-major layout. You can see this from the output of the `for` loop (0, 3, 1, 4, 2, 5). CuTe calls this index `i` a "1-D coordinate," versus the "natural coordinate," which would be the logical 2-D coordinate. - -If you're familiar with the C++23 feature `mdspan`, -this is an important difference between -`mdspan` layout mappings and CuTe `Layout`s. -`mdspan` layout mappings are *one way*: -they always take a multidimensional logical coordinate, -and they return an integer offset. -Depending on the strides, -the offset may skip over elements of the physical 1-D array. -Thus, `mdspan`'s offset does NOT mean the same thing as -the 1-D logical coordinate `i` in the `for` loop above. -You can iterate correctly over any CuTe `Layout` -by using the 1-D logical coordinate. -`mdspan` doesn't have an idea of a 1-D logical coordinate. - -### Rank, depth, size, cosize - -*Rank*: the tuple size of the layout's shape. - -*Depth*: the depth of the layout's shape. A single integer has depth 0. A tuple has depth 1 + the max depth of its components. - -*Size*: Size of the shape; size of the domain of the function. This is the product of all extents in the layout's shape. - -*Cosize*: Size of the function's codomain (not necessarily the range); for a layout A, A(size(A) - 1) + 1. (Here, we use size(A) - 1 as a 1-D logical coordinate input.) - -### Layout compatibility - -We say that layouts A and B are *compatible* if their shapes are compatible. Shape A is compatible with shape B if any natural coordinate of A is also a valid coordinate for B. - -### Flatten - -The `flatten` operation "un-nests" a potentially nested Layout. For example, - -```c++ -Layout layout = Layout, _1>, - Stride, _0>>{}; -Layout flat_layout = flatten(layout); -``` - -results in `flat_layout` having the following type - -```text -Layout, Stride<_3, _1, _0>> -``` - -and - -```c++ -Layout layout = Layout>, - Stride<_4, Stride<_1, _16>>>{}; -Layout flat_layout = flatten(layout); -``` - -results in `flat_layout` having the following type - -```text -Layout, Stride<_4, _1, _16>> -``` - -Hierarchical Layouts and flattening let us reinterpret tensors in place as matrices, matrices as vectors, vectors as matrices, etc. This lets us implement arbitrary tensor contractions as batched matrix multiply, by combining the contraction modes into a single mode, and combining the A, B, C, and "batch" modes as needed to reach the desired form. - -### Coalesce - -The `coalesce` operation first flattens the layout, then combines all the modes that are possible to combine, starting with mode 0 (the leftmost mode) and moving right. If all the modes can be combined, then this results in a 1-D layout expressing what array elements the original layout accesses. - -For example, - -```text -layout: (_2,(_1,_6)):(_1,(_6,_2)) -coalesce(layout): _12:_1 -``` - -What does it mean to "combine" modes? In the above example, the flattened layout is (2, 1, 6) : (1, 6, 2). - -1. If we look at the leftmost two modes, this is just a vector of length 2 and stride 1. The middle mode has extent 1, so the corresponding stride 6 would not be observed anyway. This leaves us with (2, 6) : (1, 2). - -2. The intermediate result (2, 6) : (1, 2) is just a 2 x 6 column-major matrix, which can be coalesced into a vector of length 12 and stride 1. - -More formally, "combining all the modes" means a left fold, where the binary operation that combines two modes has three cases. - -1. If the leftmost layout is s1:d1, and the next layout is 1:d0, then combine into s1:d1. This generalizes Step 1 above. If a mode has extent 1, we can't observe its stride, so we can skip the mode. - -2. If the leftmost layout is 1:d1, and the next layout is s0:d0, then combine into s0:d0. Again, if a mode has extent 1, we can't observe its stride, so we can skip the mode. - -3. If the leftmost layout is s1:d1, and the next layout is s0 : s1*d1, then combine into s0 * s1 : d1. This generalizes Step 2 above. One can call this "noticing a column-major layout sequence." - -That's it! For example, the result of coalescing the row-major layout (2, 2) : (2, 1) is (2, 2) : (2, 1), the same layout, because none of the above three cases applies. - -### Complement - -#### Definition - -The complement B of a layout A with respect to an integer M satisfies the following properties. - -1. $A$ and $B$ are *disjoint*: $A(x) \neq B(x)$ for all $x \neq 0$ in the domain of $A$. - -2. B is *ordered*: $B(x-1) \lt B(x)$ for all $x$ in $\{0, 1, \dots, size(B) - 1\}$. - -3. B is *bounded* by M: $size(B) \geq M / size(A)$, and $cosize(B) \leq floor(M / cosize(A)) * cosize(A)$. - -Regarding disjointness: we need to specify $x \neq 0$ because CuTe layouts are linear. That is, if the domain is nonempty, the range always contains zero. - -Regarding the ordered property: CuTe layouts are hierarchically strided, so this implies that if size(B) is nonzero, then the strides of B are all positive. - -#### Examples - -complement(4:1, 24) is 6:4. - -1. The result is disjoint of 4:1, so it must have a stride of at least 4 (since it includes 0, but must skip over 1, 2, 3). - -2. The size of the result is $\geq 24 / 4 = 6$. (This plus Step (1) means that the cosize is at least 24.) - -3. The cosize of the result is $\leq (24 / 4) * 4 = 24$. (This plus Step (2) means that the cosize is exactly 24.) - -4. The only (1-D) layout with size 6 and cosize 24 is 6:4. - -complement(6:4, 24) is 4:1. - -1. 4:1 is disjoint of 6:4, but so is s:d - for any s > 0 and d > 20. - -2. The size of the result is $\geq 24 / 6 = 4$. - -3. The cosize of the result is $\leq (24 / 21) * 21 = 21$. - -4. The stride cannot be greater than 20 - (else (2) would contradict (3)), - so it must be less than 4. - -5. This leaves 4:1 by elimination. - -### Composition - -Layouts are functions, so composition of layouts is just composition of functions. The composition $A \circ B$ means "apply the layout B first, then treat the result as a 1-D logical coordinate input to the layout A, and apply A to it." Very often, this composition can be represented as another Layout. - -#### Rules for computing composition - -Both humans and CuTe compute composition using the following rules. - -1. $A \circ B$ has a shape that is compatible with B. In function composition, the rightmost function defines the domain. For `Layout`s this means that any valid coordinate for $B$ can also be used as a coordinate for $A \circ B$. - -2. Concatenation: A layout can be expressed as the concatenation of its sublayouts. We denote concatenation with parentheses: $B = (B_0,B_1,...)$. The CuTe function `make_layout`, when given zero or more `Layout`s, concatenates them. - -3. Composition is (left-)distributive with concatenation: $A \circ B = A \circ (B_0, B_1, ...) = (A \circ B_0, A \circ B_1, ...)$. - -4. "Base case": For layouts $A = a : b$ and $B = c : d$ with integral shape and stride, $A \circ B = R = c : (b * d)$. - -5. By-mode composition: Let $\langle B, C \rangle$ (angle brackets, not parentheses) - denote a tuple of two layouts B and C, not their concatenation. Let $A = (A_0, A_1)$. - Then, $A \circ \langle B, C \rangle = (A_0, A_1) \circ \langle B, C \rangle = (A_0 \circ B, A_1 \circ C)$. - This allows the application of composition independently to sublayouts of $A$. - -#### Examples: Reshape a vector into a matrix - -This section gives two composition examples. Both start with a vector with layout $20:2$ (that is, the vector has 20 elements, and the stride between each is 2). They compose this vector with a 4 x 5 matrix layout. This effectively "reshapes" the vector in place into a matrix. - -##### Example 1 - -$20:2 \circ (4,5) : (1,4)$. - -This describes interpreting the vector $20:2$ -as a 4 x 5 column-major matrix. - -The resulting layout has shape $(4,5)$, -because in function composition, -the rightmost function defines the domain. -What are the strides? - -1. A layout can be expressed as the concatenation of its sublayouts, - so $(4,5) : (1,4)$ is $(4:1, 5:4)$. - -2. Composition is distributive, so - $20:2 \circ (4:1, 5:4)$ is $(20:2 \circ 4:1, 20:2 \circ 5:4)$. - -3. $20:2 \circ 4:1$ has shape 4 (rightmost function defines the domain) - and stride $2 = 2 \cdot 1$. - -4. $20:2 \circ 5:4$ has shape 5 and stride $8 = 2 \cdot 4$. - -5. Result: (4:2, 5:8), which by concatenation is (4,5) : (2,8). - -#### Example 2 - -$20:2 \circ (4,5) : (5,1)$. - -This describes interpreting the vector 20:2 -as a 4 x 5 row-major matrix. - -The resulting layout has shape $(4,5)$, just as before. What are the strides? - -1. By deconcatenation, $(4,5) : (5,1)$ is $(4:5, 5:1)$. - -2. Composition is distributive, so $20:2 \circ (4:5, 5:1)$ is $(20:2 \circ 4:5, 20:2 \circ 5:1)$. - -3. $20:2 \circ 4:5$ has shape $4$ and stride $10 = 2 \cdot 5$. - -4. $20:2 \circ 5:1$ has shape $5$ and stride $2 = 2 \cdot 1$. - -5. Result: (4:10, 5:2), which by concatenation is (4,5) : (10,2). - -#### Example: Reshape a matrix into another matrix - -The composition $((20,2):(16,4) \circ (4,5):(1,4))$ -expresses reshaping the matrix with layout (20,2):(16:4), -into a 4 x 5 matrix in a column-major way. - -1. By deconcatenation, $(4,5) : (1,4)$ is $(4:1, 5:4)$. - -2. Composition is distributive, so $(20,2):(16,4) \circ (4:1, 5:4)$ is $((20,2):(16,4) \circ 4:1, (20,2):(16,4) \circ 5:4)$. - -3. $(20,2):(16,4) \circ 4:1$ has shape $4$ and stride $16$. (4:1 expresses picking the first 4 consecutive elements of (20,2):(16,4). These elements run down the 0th column (leftmost mode) of the layout, whose stride is 16.) - -4. $(20,2):(16,4) \circ 5:4$ has shape $5$ and stride $64 = 4 \cdot 16$. - -5. Result: $(4:16, 5:64)$, which by concatenation is $(4,5) : (16,64)$. - -We get exactly this result with CuTe -if we use compile-time shapes and strides. -The following C++ code prints `(_4,_5):(_16,_64).` - -```c++ -using namespace cute; -auto a = make_layout(make_shape(Int<20>{}, _2{}), make_stride(_16{}, _4{})); -auto b = make_layout(make_shape( _4{}, _5{}), make_stride( _1{}, _4{})); -auto c = composition(a, b); -printf("\n"); -print(c); -``` - -Results may _look_ different (but are the same mathematically) -if we use run-time integers. -The following C++ code prints `((4,1),(5,1)):((16,4),(64,4)).` - -```c++ -using namespace cute; -auto a = make_layout(make_shape(20, 2), make_stride(16, 4)); -auto b = make_layout(make_shape( 4, 5), make_stride( 1, 4)); -auto c = composition(a, b); -printf("\n"); -print(c); -``` - -((4,1),(5,1)):((16,4),(64,4)) is effectively the same layout -as (4,5) : (16,64), because the 1s in the shape don't affect the layout -(as a mathematical function from one integer to one integer). -CuTe chooses not to simplify layout computations -with run-time values in them as much as it could, -because simplifications involving run-time values have a run-time cost. - -### Product - -CuTe includes four different kinds of layout products. - -1. `logical_product` - -2. `blocked_product` - -3. `raked_product` - -4. `tiled_product` - -`logical_product(A, B)` results in a layout where each element of layout B -has been replaced by a "copy" of layout A. -The other three products offer variations of this idea. - -#### Example: Tiled matrix - -Suppose that I want to make a matrix consisting of 3 x 4 tiles -in a row-major arrangement, -where each tile is a 2 x 2 column-major matrix. - -The Layout of each tile (tile) has Shape (2,2) and Stride (1,2). - -The Layout of the "matrix of tiles" (`matrix_of_tiles`) -has Shape (3,4) and Stride (4,1). - -##### Blocked product: the intuitive tiling - -If I were to deduce by hand what the layout of the tiled matrix should be, -it would look like this. - -| | (0,0) | (1,0) | (0,1) | (1,1) | (0,2) | (1,2) | (0,3) | (1,3) | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | -| (0,0) | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | -| (1,0) | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | -| (0,1) | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 | -| (1,1) | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 | -| (0,2) | 32 | 34 | 36 | 38 | 40 | 42 | 44 | 46 | -| (1,2) | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | - -The row and column labels use the equivalence of 1-D logical coordinates and 2-D column-major coordinates. The left index in each pair is the row resp. column coordinate of the tile, while the right index in each pair is the row resp. column coordinate of the matrix-of-tiles. The resulting layout has Shape ((2, 3), (2, 4)), and Stride ((1, 16), (2, 4)), and the second mode can be coalesced. The Shape ((2, 3), (2, 4)) is hierarchical, but it is still rank-2 and can be drawn in 2D as above. Note how the row mode of the tile remains part of the row mode of the product, and the column mode of the tile remains a column mode of the product. - -The above layout is what `blocked_product(tile, matrix_of_tiles)` produces. -A critical use case for blocked product is "tiling" an "atom" -(some tile that relates to a hardware feature) over a matrix. - -```c++ -Layout tile = Layout, - Stride<_1,_2>>{}; -Layout matrix_of_tiles = Layout, - Stride<_4,_1>>{}; - -print_layout(blocked_product(tile, matrix_of_tiles)); -``` - -##### Logical product - -The logical product `logical_product(tile, matrix_of_tiles)` -results in Shape ((2, 2), (3, 4)) and Stride ((1, 2), (16, 4)). - -| | (0,0) | (1,0) | (2,0) | (0,1) | (1,1) | (2,1) | (0,2) | (1,2) | (2,2) | (0,3) | (1,3) | (2,3) | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| (0,0) | 0 | 16 | 32 | 4 | 20 | 36 | 8 | 24 | 40 | 12 | 28 | 44 | -| (1,0) | 1 | 17 | 33 | 5 | 21 | 37 | 9 | 25 | 41 | 13 | 29 | 45 | -| (0,1) | 2 | 18 | 34 | 6 | 22 | 38 | 10 | 26 | 42 | 14 | 30 | 46 | -| (1,1) | 3 | 19 | 35 | 7 | 23 | 39 | 11 | 27 | 43 | 15 | 31 | 47 | - -Note how the tile appears in the leftmost column and is reproduced -in each column in the same order as the matrix-of-tiles. That is, -the tile can be indexed through the first mode of the result and the -matrix-of-tiles can be indexed through the second mode. - -```c++ -Layout tile = Layout, - Stride<_1,_2>>{}; -Layout matrix_of_tiles = Layout, - Stride<_4,_1>>{}; - -print_layout(logical_product(tile, matrix_of_tiles)); -``` - -##### Raked product - -The raked product `raked_product(tile, matrix_of_tiles)` results in -Shape ((3, 2), (4, 2)) and Stride ((16, 1), (4, 2)). - -| | (0,0) | (1,0) | (2,0) | (3,0) | (0,1) | (1,1) | (2,1) | (3,1) | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | -| (0,0) | 0 | 4 | 8 | 12 | 2 | 6 | 10 | 14 | -| (1,0) | 16 | 20 | 24 | 28 | 18 | 22 | 26 | 30 | -| (2,0) | 32 | 36 | 40 | 44 | 34 | 38 | 42 | 46 | -| (0,1) | 1 | 5 | 9 | 13 | 3 | 7 | 11 | 15 | -| (1,1) | 17 | 21 | 25 | 29 | 19 | 23 | 27 | 31 | -| (2,1) | 33 | 37 | 41 | 45 | 35 | 39 | 43 | 47 | - -The tile is now interleaved or "raked" with the other 3x4 matrix-of-tiles -instead of appearing as blocks. Other references call this a "cyclic -distribution." - -This might look familiar if you have ever used ScaLAPACK. -It expresses a 2-D block cyclic distribution of a 6 x 8 matrix -over 4 processes in a 2 x 2 "process grid." See -["The Two-dimensional Block-Cyclic Distribution"](https://netlib.org/scalapack/slug/node75.html#sec2dbcd) -and -["Local Storage Scheme and Block-Cyclic Mapping"](https://netlib.org/scalapack/slug/node76.html#seclocalstorage) -in the ScaLAPACK Users' Guide. - -In general, `logical_product` and these variations can produce any interleaving, -including blocked, cyclic, by-mode blocked/cyclic, and intermediate interleavings -that don't have common names. - -```c++ -Layout tile = Layout, - Stride<_1,_2>>{}; -Layout matrix_of_tiles = Layout, - Stride<_4,_1>>{}; - -print_layout(raked_product(tile, matrix_of_tiles)); -``` - -### Division - -The previous section covered layout products, -that reproduce one layout over another. -This section covers layout *division*. -Functions that divide a layout into components are useful -as a basis for tiling and partitioning layouts. - -For example, consider folding a vector into a matrix. -We could imagine an operation, called `logical_divide`, - -```c++ -Layout vec = Layout<_16,_3>{}; // 16 : 3 -Layout col = Layout< _4,_1>{}; // 4 : 1 -Layout mat = logical_divide(vec, col); // (4,4) : (3,12) -``` - -that "takes" the first 4 elements of the vector into the first mode -and leaves the "rest" in the second mode. This is a column-major matrix -view of the data in `vec`. -What if we want a row-major matrix view? - -```c++ -Layout vec = Layout<_16,_3>{}; // 16 : 3 -Layout col = Layout< _4,_4>{}; // 4 : 4 -Layout mat = logical_divide(vec, col); // (4,4) : (12,3) -``` - -Now, every fourth element of the vector is in the first mode and -the "rest" are in the second mode. -Multidimensional, hierarchical indices let us extend this operation -to any layout that "divides" the vector. - -```c++ -Layout vec = Layout<_16,_3>{}; // 16 : 3 -Layout col = Layout< _4,_2>{}; // 4 : 2 -Layout mat = logical_divide(vec, col); // (4,(2,2)) : (6,(3,24)) -``` - -```c++ -Layout vec = Layout<_16,_3>{}; // 16 : 3 -Layout col = Layout, - Stride<_4,_1>>{}; // (2,2) : (4,1) -Layout mat = logical_divide(vec, col); // ((2,2),(2,2)) : ((12,3),(6,24)) -``` - -All of the above examples produce a 4x4 matrix -that can be indexed and treated like a normal 4x4 matrix, -but each has a different underlying layout. -Thus, our algorithms can be written using logical coordinates, -without needing to address the detailed indexing that each layout requires. - -CuTe includes 3 different kinds of layout division operations. - -1. `logical_divide` - -2. `zipped_divide` - -3. `tiled_divide` - -We will summarize these in the sections that follow. - -#### Logical divide - -##### Example worked in detail - -This section will work the following logical divide example in detail. - -```c++ -Layout a = make_layout(24, 2); -Layout b = make_layout( 4, 2); -Layout c = logical_divide(a, b); -``` - -Logical divide produces a rank-2 `Layout`, -where mode 0 (the leftmost mode) corresponds to the divisor `b`, -and mode 1 (the rightmost mode) corresponds to the "remainder." -Intuitively, the remainder of 24 divided by 4 is 6, -so we know that mode 1 has 6 elements. -We just don't know its shape yet. - -CuTe defines `logical_divide(a, b)` as -`composition(a, make_layout(b, complement(b, size(a))))`. -Here, `size(a)` is 24. -What is `complement(b, 24)`? -Intuitively, it means "the remainder," -what's left over after applying `b` to 0, 1, 2, $\dots$, 23. - -The layout 4:2 means "take 4 elements at even-numbered indices." -The following table overlays the range of 4:2 -atop the complement's codomain 0, 1, $\dots$, 23. - -| Range of 4:2 | 0 | | 2 | | 4 | | 6 | | | | | | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| Codomain | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | $\dots$ | 23 | - -Layouts are linear, so their range must include zero. -The complement of 4:2 with respect to 24 is thus a layout whose range - -* includes zero; - -* does not include any other elements of the range of 4:2 - (i.e., satisfies the disjoint property; see above); and - -* includes as much of 0, 1, $\dots$, 23 as possible - (so that it forms the "remainder" of 4:2 with respect to 24). - -Intuitively, the range of the complement must look like this: -0, 1, 8, 9, 16, 17. -The resulting layout is ordered. -It has size 6 and cosize 18, -so it satisfies the bounded property (see above). -This is the layout (2, 3) : (1, 8). -(Going from this intuitive sense of the complement -to knowing how to compute it directly -is out of scope for this part of the tutorial.) - -The following table shows 4:2 with its complement (2, 3) : (1, 8). - -| Range of 4:2 | 0 | | 2 | | 4 | | 6 | | | | | | | | | | | | | | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| Codomain | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | $\dots$ | 23 | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| Range of complement | 0 | 1 | | | | | | | 8 | 9 | | | | | | | 16 | 17 | | | - -Now we know that `logical_divide`(24:2, 4:2) is -`composition`(24:2, `make_layout`(4:2, (2,3):(1,8))). -The composition of two layouts has the shape of the second (rightmost) layout, -so the resulting shape is (4, (2, 3)). -We see that the leftmost mode 4 corresponds to the divisor 4:2, -and the rightmost mode (2, 3) describes what's "left over" -from the original shape 24. - -What are the strides? -We can start from the leftmost mode. -4:2 takes every other element (the even-numbered elements) of 24:2. -That's a stride-2 thing, striding over a stride-2 thing. -The resulting stride is 4. -Similarly, the stride 2 of 24:2 -doubles the two strides of the rightmost mode. -The resulting layout is (4, (2, 3)) : (4, (2, 16)). - -##### Tiling example - -Suppose I have the 6 x 8 matrix from the Raked Product section -and want to "collect" the `tile`, turning the Raked Product into -the Blocked Product. - -To do this, we would like to gather two elements from the column -and leave the rest, then gather two elements from the row and leave the rest. -Thus, we want to apply `logical_divide` independently to the rows and cols -in order to retrieve the appropriate elements. - -In code, we copy the Layout from the result of the Raked Product section, then -specify the elements in the rows and cols we would like to gather. - -```c++ -Layout raked_prod = Layout,Shape <_4,_2>>, - Stride,Stride<_4,_2>>>{}; -Tile subtile = make_tile(Layout<_2,_3>{}, // Gather elements 2 : 3 from mode 0 - Layout<_2,_4>{}); // Gather elements 2 : 4 from mode 1 - -print_layout(logical_divide(raked_prod, subtile)); -``` - -Indeed, this does produce the result from the Blocked Product section. - -| | (0,0) | (1,0) | (0,1) | (1,1) | (0,2) | (1,2) | (0,3) | (1,3) | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | -| (0,0) | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | -| (1,0) | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | -| (0,1) | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 | -| (1,1) | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 | -| (0,2) | 32 | 34 | 36 | 38 | 40 | 42 | 44 | 46 | -| (1,2) | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | - -Of course, any other rearrangement of the rows and cols is also valid. - -#### Zipped divide - -The `zipped_divide` function applies `logical_divide`, and then gathers the -"subtiles" into a single mode and the "rest" into a single mode. - -For example, if we apply `zipped_divide` instead of `logical_divide` in the example above, - -```c++ -Layout raked_prod = Layout,Shape <_4,_2>>, - Stride,Stride<_4,_2>>>{}; -Tile subtile = make_tile(Layout<_2,_3>{}, // Gather elements 2 : 3 from mode 0 - Layout<_2,_4>{}); // Gather elements 2 : 4 from mode 1 - -print_layout(zipped_divide(raked_prod, subtile)); -``` - -then we get the result - -| | (0,0) | (1,0) | (2,0) | (0,1) | (1,1) | (2,1) | (0,2) | (1,2) | (2,2) | (0,3) | (1,3) | (2,3) | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| (0,0) | 0 | 16 | 32 | 4 | 20 | 36 | 8 | 24 | 40 | 12 | 28 | 44 | -| (1,0) | 1 | 17 | 33 | 5 | 21 | 37 | 9 | 25 | 41 | 13 | 29 | 45 | -| (0,1) | 2 | 18 | 34 | 6 | 22 | 38 | 10 | 26 | 42 | 14 | 30 | 46 | -| (1,1) | 3 | 19 | 35 | 7 | 23 | 39 | 11 | 27 | 43 | 15 | 31 | 47 | - -Note that this is the same layout as the result in the Logical Product section. -That is, the first mode is our original tile (and can be interpreted as a 2x2 matrix itself) -and the second mode is its logical layout within the raked layout. - -#### More Examples of Divide - -For brevity, shapes can be used with `logical_divide` and `tiled_divide` to quickly split and tile modes of a tensor. For example, this C++ code - -```c++ -Layout layout = Layout, - Stride< _1,_128,_0>>{}; -Shape tile_shape = make_shape(_4{},_8{}); -Layout logical_divided_tile = logical_divide(layout, tile_shape); -Layout zipped_divided_tile = zipped_divide(layout, tile_shape); - -print("layout : "); print(layout); print("\n"); -print("tile_shape : "); print(tile_shape); print("\n"); -print("logical_divided_tile : "); print(logical_divided_tile); print("\n"); -print("zipped_divided_tile : "); print(zipped_divided_tile); print("\n\n"); -``` - -produces the following output when we vary `layout`. - -```text -full_layout : (_12,_32,_6):(_1,_128,_0) -tile_shape : (_4,_8) -logical_divided_tile : ((_4,_3),(_8,_4),_6):((_1,_4),(_128,_1024),_0) -zipped_divided_tile : ((_4,_8),(_3,_4,_6)):((_1,_128),(_4,_1024,_0)) - -full_layout : (_12,(_4,_8),_6):(_1,(_32,_512),_0) -tile_shape : (_4,_8) -logical_divided_tile : ((_4,_3),((_4,_2),_4),_6):((_1,_4),((_32,_512),_1024),_0) -zipped_divided_tile : ((_4,(_4,_2)),(_3,_4,_6)):((_1,(_32,_512)),(_4,_1024,_0)) -``` - -This code - -```c++ -Layout layout = make_layout(Shape<_8,_8>{}, - Stride<_8,_1>{}); -Layout tile = make_tile(make_layout(Shape<_4>{}), - make_layout(Shape<_2>{})); -print("layout: "); -print_layout(layout); -print("\n"); -print("tile: "); -print(tile); -print("\n"); -print("logical_divide: "); -print_layout(logical_divide(layout, tile)); -print("zipped_divide: "); -print_layout(zipped_divide(layout, tile)); -``` - -results in the following layouts. - -

- logical_divide-and-zipped_divide -

- -This code - -```c++ -Layout layout = make_layout(Shape<_8,_8>{}, - Stride<_8,_1>{}); -Layout tile = make_tile(make_layout(Shape<_2>{}), - make_layout(Shape<_4>{})); -print("layout: "); -print_layout(layout); -print("\n"); -print("tile: "); -print(tile); -print("\n"); -print("logical_divide: "); -print_layout(logical_divide(layout, tile)); -print("zipped_divide: "); -print_layout(zipped_divide(layout, tile)); -``` - -results in the following layouts. - -

- logical_divide-and-zipped_divide-2 -

- -#### Tiled divide - -The `tiled_divide` function works like `zipped_divide`, -except that it unpacks the second mode. This is useful when you have a `Tile` that describes all of the elements for a particular operation, for example, and want to gather those together but retain the logical shape of those tiles within the original layout. That is, - -```text -Layout Shape : (M, N, L, ...) -Tile Shape : -Tiled Result : ((M', N'), m, n, L, ...) -``` - -where `m` is `M / M'` and `n` is `N / N'`. -We can consider `m` as the "number of `Tile`s in `M`" and `n` as the "number of `Tile`s in `N`". This style of operation is common when applying MMA Atoms and Copy Atoms. diff --git a/media/docs/cute/0t_mma_atom.md b/media/docs/cute/0t_mma_atom.md index d742851f..c79ae124 100644 --- a/media/docs/cute/0t_mma_atom.md +++ b/media/docs/cute/0t_mma_atom.md @@ -142,13 +142,13 @@ directory, in header files starting with `mma_traits`. An `MMA_Traits` specialization defines the following public type aliases. -* `ElementDVal`: Compute type of the D matrix +* `ValTypeD`: Compute type of the D matrix -* `ElementAVal`: Compute type of the A matrix +* `ValTypeA`: Compute type of the A matrix -* `ElementBVal`: Compute type of the B matrix +* `ValTypeB`: Compute type of the B matrix -* `ElementCVal`: Compute type of the C matrix +* `ValTypeC`: Compute type of the C matrix * `Shape_MNK`: Logical MxNxK shape of the MMA operation @@ -172,10 +172,10 @@ It looks like this. template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; using Shape_MNK = Shape<_8,_8,_4>; using ThrID = SM70_QuadPair; @@ -207,10 +207,10 @@ We first take a look at how we would take the ISA semantics of thread and data p The HMMA NT above uses types: ```cpp - using ElementDVal = float; - using ElementAVal = half_t; - using ElementBVal = half_t; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; ``` The rest of the `MMA_Traits` will be described in units of these types. diff --git a/media/images/cute/complement1.png b/media/images/cute/complement1.png new file mode 100644 index 00000000..24fbb68f Binary files /dev/null and b/media/images/cute/complement1.png differ diff --git a/media/images/cute/composition1.png b/media/images/cute/composition1.png new file mode 100644 index 00000000..0d330a66 Binary files /dev/null and b/media/images/cute/composition1.png differ diff --git a/media/images/cute/composition2.png b/media/images/cute/composition2.png new file mode 100644 index 00000000..7581a5d7 Binary files /dev/null and b/media/images/cute/composition2.png differ diff --git a/media/images/cute/divide1.png b/media/images/cute/divide1.png new file mode 100644 index 00000000..534666e5 Binary files /dev/null and b/media/images/cute/divide1.png differ diff --git a/media/images/cute/divide2.png b/media/images/cute/divide2.png new file mode 100644 index 00000000..7f6f95de Binary files /dev/null and b/media/images/cute/divide2.png differ diff --git a/media/images/cute/divide3.png b/media/images/cute/divide3.png new file mode 100644 index 00000000..a5073faf Binary files /dev/null and b/media/images/cute/divide3.png differ diff --git a/media/images/cute/product1.png b/media/images/cute/product1.png new file mode 100644 index 00000000..7d966323 Binary files /dev/null and b/media/images/cute/product1.png differ diff --git a/media/images/cute/product2.png b/media/images/cute/product2.png new file mode 100644 index 00000000..572beb04 Binary files /dev/null and b/media/images/cute/product2.png differ diff --git a/media/images/cute/product2d.png b/media/images/cute/product2d.png new file mode 100644 index 00000000..b13a9fb3 Binary files /dev/null and b/media/images/cute/product2d.png differ diff --git a/media/images/cute/productblocked2d.png b/media/images/cute/productblocked2d.png new file mode 100644 index 00000000..84862272 Binary files /dev/null and b/media/images/cute/productblocked2d.png differ diff --git a/media/images/cute/productraked2d.png b/media/images/cute/productraked2d.png new file mode 100644 index 00000000..7d121ff4 Binary files /dev/null and b/media/images/cute/productraked2d.png differ diff --git a/python/cutlass/backend/__init__.py b/python/cutlass/backend/__init__.py index f1dce8d7..9a4e2f67 100644 --- a/python/cutlass/backend/__init__.py +++ b/python/cutlass/backend/__init__.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + from cutlass.backend.arguments import * from cutlass.backend.c_types import * from cutlass.backend.compiler import ArtifactManager diff --git a/python/cutlass/backend/arguments.py b/python/cutlass/backend/arguments.py index 2fd988fc..b1e04a14 100644 --- a/python/cutlass/backend/arguments.py +++ b/python/cutlass/backend/arguments.py @@ -56,16 +56,9 @@ def __init__( **kwargs, ) -> None: # tensor_C can be interpreted as the bias with bias=True in keyword args - if "bias" in kwargs.keys(): - self.bias = kwargs["bias"] - else: - # by default, tensor_C is not bias - self.bias = False + self.bias = kwargs.get("bias", False) - if "stream" in kwargs.keys(): - self.stream = kwargs["stream"] - else: - self.stream = cuda.CUstream(0) + self.stream = kwargs.get("stream", cuda.CUstream(0)) # RMM buffers used to track tensor lifetime self.buffers = {} diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py index e323b986..faefd135 100644 --- a/python/cutlass/backend/conv2d_operation.py +++ b/python/cutlass/backend/conv2d_operation.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# import ctypes from typing import Union diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index 784f8e95..214a0942 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# import ctypes diff --git a/python/cutlass/backend/evt/__init__.py b/python/cutlass/backend/evt/__init__.py index 6c82b71a..a7cad2ea 100644 --- a/python/cutlass/backend/evt/__init__.py +++ b/python/cutlass/backend/evt/__init__.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor from cutlass.backend.evt.frontend import PythonASTFrontend diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py index b555deb7..c0c780be 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass/backend/evt/epilogue.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# """ Epilogue Visitor interface for compiling, and running visitor-based epilogue. diff --git a/python/cutlass/backend/evt/ir/compute_nodes.py b/python/cutlass/backend/evt/ir/compute_nodes.py index 21592955..783d7cf1 100644 --- a/python/cutlass/backend/evt/ir/compute_nodes.py +++ b/python/cutlass/backend/evt/ir/compute_nodes.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# """ Python registration for compute nodes in EVT diff --git a/python/cutlass/backend/evt/ir/layout_algorithm.py b/python/cutlass/backend/evt/ir/layout_algorithm.py index 3da35b8d..dd990303 100644 --- a/python/cutlass/backend/evt/ir/layout_algorithm.py +++ b/python/cutlass/backend/evt/ir/layout_algorithm.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# """ Layout algebras diff --git a/python/cutlass/backend/frontend.py b/python/cutlass/backend/frontend.py index a39635fa..2b907cc7 100644 --- a/python/cutlass/backend/frontend.py +++ b/python/cutlass/backend/frontend.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# from cuda import cuda import numpy as np diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 85b64f29..2749fe17 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -1,4 +1,4 @@ -################################################################################ +################################################################################################# # # Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# import copy import ctypes @@ -712,6 +712,8 @@ def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs): self.gemm_arguments = [] + self.stream = kwargs.get("stream", cuda.CUstream(0)) + # Process the input arguments for idx, problem_size in enumerate(problem_sizes): M, N, K = problem_size.m, problem_size.n, problem_size.k @@ -771,11 +773,6 @@ def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs): self.output_op = kwargs["output_op"] else: self.output_op = self.operation.epilogue_type(1.0, 0.0) - - if "stream" in kwargs.keys(): - self.stream = kwargs["stream"] - else: - self.stream = cuda.CUstream(0) # Get host problem size self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0] diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index 426e721f..568c1f69 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# import ctypes diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 91a7f94a..73cdaadc 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -657,7 +657,10 @@ def __exit__(self, exc_type, exc_val, traceback): """ Restores the old value of TORCH_CUDA_ARCH_LIST """ - os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list + if self.old_arch_list is None: + del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] + else: + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py index e2c4389d..0c0d626d 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass/op/conv.py @@ -112,6 +112,7 @@ args.sync() """ +from cuda import cuda from cutlass_library import ( ConvKind, ConvMode, @@ -131,7 +132,6 @@ from cutlass.op.op import OperationBase from cutlass.shape import Conv2DProblemSize, MatrixCoord from cutlass.utils import check, datatypes -from cuda import cuda class Conv2d(OperationBase): diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index 61e4f6a8..38c06a3b 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -116,6 +116,7 @@ from math import prod +from cuda import cuda from cutlass_library import ( DataType, DataTypeSize, @@ -131,7 +132,6 @@ from cutlass.op.op import OperationBase from cutlass.shape import GemmCoord from cutlass.utils import check, datatypes -from cuda import cuda class Gemm(OperationBase): @@ -691,6 +691,7 @@ def run(self, A=None, B=None, C=None, D=None, 'D': self._get_batch_stride(D) } } + kwargs['stream'] = stream if isinstance(self.epilogue_functor, EpilogueFunctorVisitor): diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index 162e0493..34dfcac4 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -53,6 +53,7 @@ from cutlass_library import DataTypeSize +from cuda import cuda from cutlass.backend.gemm_operation import ( GemmGroupedArguments, GemmOperationGrouped, @@ -65,7 +66,6 @@ from cutlass.op.gemm import Gemm from cutlass.shape import GemmCoord from cutlass.utils import check, datatypes -from cuda import cuda class GroupedGemm(Gemm): diff --git a/python/cutlass/shape.py b/python/cutlass/shape.py index 6e21dbba..37341463 100644 --- a/python/cutlass/shape.py +++ b/python/cutlass/shape.py @@ -1,6 +1,6 @@ -################################################################################ +################################################################################################# # -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################ +################################################################################################# """ Utilities for expressing shapes diff --git a/python/docs_src/source/conf.py b/python/docs_src/source/conf.py index 57cd633d..762dd037 100644 --- a/python/docs_src/source/conf.py +++ b/python/docs_src/source/conf.py @@ -1,3 +1,35 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + # Configuration file for the Sphinx documentation builder. # # For the full list of built-in configuration values, see the documentation: diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 7e78a218..5b8c6528 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -17,7 +17,7 @@ # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE diff --git a/python/setup_library.py b/python/setup_library.py index 115e6c0a..dcae3ec3 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -17,7 +17,7 @@ # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE diff --git a/python/setup_pycute.py b/python/setup_pycute.py index bf06967d..3317752a 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -17,7 +17,7 @@ # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE diff --git a/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/test/python/cutlass/conv2d/conv2d_problem_sizes.py index bf164207..502c49a7 100644 --- a/test/python/cutlass/conv2d/conv2d_problem_sizes.py +++ b/test/python/cutlass/conv2d/conv2d_problem_sizes.py @@ -23,7 +23,7 @@ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) HOWEVER +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/unit/cute/core/bitfield.cpp b/test/unit/cute/core/bitfield.cpp index 4899e47a..edbfdf44 100644 --- a/test/unit/cute/core/bitfield.cpp +++ b/test/unit/cute/core/bitfield.cpp @@ -38,9 +38,10 @@ #include #include -#include #include +#include + using namespace cute; TEST(CuTe_core, Bitfield) diff --git a/test/unit/cute/core/complement.cpp b/test/unit/cute/core/complement.cpp index cfad54ff..fa628d33 100644 --- a/test/unit/cute/core/complement.cpp +++ b/test/unit/cute/core/complement.cpp @@ -43,26 +43,30 @@ test_complement(Layout const& layout, CoSizeHi const& cosize_hi) auto result = complement(layout, cosize_hi); - CUTLASS_TRACE_HOST("complement( " << layout << ", " << cosize_hi << ") => " << result); + CUTLASS_TRACE_HOST("complement(" << layout << ", " << cosize_hi << ") => " << result); - // Post-condition on the domain size of the complement (1) - EXPECT_GE( size(result), cosize_hi / size(filter(layout))); - // Post-condition on the codomain size of the complement (2) - EXPECT_LE(cosize(result), cute::ceil_div(cosize_hi, cosize(layout)) * cosize(layout)); + auto completed = make_layout(layout, result); + + // Lower-bound on the codomain size of the layout ++ complement (1) + EXPECT_GE(cosize(completed), cosize_hi); + // Upper-bound on the codomain size of the complement (2) + EXPECT_LE(cosize(result), cute::round_up(cosize_hi, cosize(layout))); // Post-condition on the codomain of the complement for (int i = 1; i < size(result); ++i) { EXPECT_LT(result(i-1), result(i)); // Ordered (3) for (int j = 0; j < size(layout); ++j) { - EXPECT_NE(result(i), layout(j)); // Complemented (4) + EXPECT_NE(result(i), layout(j)); // Disjoint (4) } } // Other observations - EXPECT_LE(size(result),cosize(result)); // As a result of the ordered condition (3) - EXPECT_GE(cosize(result), cosize_hi / size(filter(layout))); // As a result of (1) (2) and (5) - if constexpr (is_static::value) { // If we can apply complement again - EXPECT_EQ(size(complement(make_layout(layout,result))), 1); // There's no more codomain left over + EXPECT_LE(size(result), cosize(result)); // As a result of the ordered condition (3) + EXPECT_GE(size(result), cosize_hi / size(filter(layout))); + EXPECT_LE(cosize(completed), cosize(result) + cosize(layout)); + EXPECT_GE(cosize(result), cosize_hi / size(filter(layout))); + if constexpr (is_static::value) { // If we can apply complement again + EXPECT_EQ(size(complement(completed)), 1); // There's no more codomain left over } } @@ -125,6 +129,7 @@ TEST(CuTe_core, Complement) test_complement(layout, Int<1>{}); test_complement(layout); test_complement(layout, Int<16>{}); + test_complement(layout, Int<19>{}); } { @@ -153,6 +158,12 @@ TEST(CuTe_core, Complement) test_complement(layout); } + { + auto layout = Layout, Stride<_1,_6>>{}; + + test_complement(layout); + } + { auto layout = Layout, Stride<_8,_1,_64>>{}; @@ -167,26 +178,34 @@ TEST(CuTe_core, Complement) } { - auto layout = make_layout(Shape,Shape<_2, _2>>{}, + auto layout = make_layout(Shape ,Shape <_2, _2>>{}, Stride,Stride<_8,_32>>{}); test_complement(layout); } { - auto layout = make_layout(Shape,Shape<_2, _2>>{}, + auto layout = make_layout(Shape ,Shape <_2,_2>>{}, Stride,Stride<_8,_4>>{}); test_complement(layout); } - // Fails due to non-injective input - //{ - //auto layout = make_layout(Shape,Shape<_2, _2>>{}, + // Fails due to non-injective layout + // { + // auto layout = make_layout(Shape,Shape<_2, _2>>{}, // Stride,Stride<_8,_4>>{}); - //test_complement(layout); - //} + // test_complement(layout); + // } + + // Fails due to non-injective layout + // { + // auto layout = Layout, Stride<_2,_3>>{}; + + // test_complement(layout); + // test_complement(layout, Int<19>{}); + // } { auto layout = Layout, Stride<_1,_6>>{}; diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp index 7934b3ce..023e992e 100644 --- a/test/unit/cute/core/composition.cpp +++ b/test/unit/cute/core/composition.cpp @@ -42,8 +42,8 @@ using namespace cute; template void -test_composition(const LayoutA& layoutA, - const LayoutB& layoutB) +test_composition(LayoutA const& layoutA, + LayoutB const& layoutB) { auto layoutR = composition(layoutA, layoutB); @@ -52,14 +52,12 @@ test_composition(const LayoutA& layoutA, CUTLASS_TRACE_HOST(" => "); CUTLASS_TRACE_HOST(layoutR); - // Test that layout R is compatible with layout B + // Test that layout B is compatible with layout R EXPECT_TRUE(compatible(layoutB, layoutR)); - // True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. - - // Test that R(c) = A(B(c)) for all coordinates c in layoutR - for (int i = 0; i < size(layoutR); ++i) { - EXPECT_EQ(layoutR(i), layoutA(layoutB(i))); + // Test that R(c) = A(B(c)) for all coordinates c in layoutB + for (int c = 0; c < size(layoutB); ++c) { + EXPECT_EQ(layoutR(c), layoutA(layoutB(c))); } } diff --git a/test/unit/cute/core/logical_divide.cpp b/test/unit/cute/core/logical_divide.cpp index 5d37b829..840bb7f9 100644 --- a/test/unit/cute/core/logical_divide.cpp +++ b/test/unit/cute/core/logical_divide.cpp @@ -45,10 +45,10 @@ test_logical_divide(LayoutA const& layoutA, auto layoutR = logical_divide(layoutA, layoutB); CUTLASS_TRACE_HOST("test_logical_divide()"); - CUTLASS_TRACE_HOST(shape(layoutA) << " / " << shape(layoutB) << " => " << shape(layoutR) ); + CUTLASS_TRACE_HOST( shape(layoutA) << " / " << shape(layoutB) << " => " << shape(layoutR)); CUTLASS_TRACE_HOST(stride(layoutA) << " " << stride(layoutB) << " => " << stride(layoutR)); - // Test that layout R is compatible with layout B + // Test that layout B is compatible with layout R_0 ASSERT_EQ(rank(layoutR), 2); ASSERT_TRUE(compatible(layoutB, layout<0>(layoutR))); } @@ -186,10 +186,10 @@ TEST(CuTe_core, Logical_divide) // Enforcement for dynamic cases auto result = logical_divide(layout, tile); - static_assert(decltype(shape<0>(result) == Int<32>{})::value); - static_assert(decltype(stride<0>(result) == Int<1>{})::value); - assert(shape<1>(result) == 1); - static_assert(decltype(stride<1>(result) == Int<32>{})::value); + ASSERT_TRUE(decltype(shape<0>(result) == Int<32>{})::value); + ASSERT_TRUE(decltype(stride<0>(result) == Int<1>{})::value); + ASSERT_TRUE(shape<1>(result) == 1); + ASSERT_TRUE(decltype(stride<1>(result) == Int<32>{})::value); } { @@ -200,10 +200,10 @@ TEST(CuTe_core, Logical_divide) // Enforcement for dynamic cases auto result = logical_divide(layout, tile); - static_assert(decltype(shape<0>(result) == Int<32>{})::value); - static_assert(decltype(stride<0>(result) == Int<1>{})::value); - assert(shape<1>(result) == 2); - static_assert(decltype(stride<1>(result) == Int<32>{})::value); + ASSERT_TRUE(decltype(shape<0>(result) == Int<32>{})::value); + ASSERT_TRUE(decltype(stride<0>(result) == Int<1>{})::value); + ASSERT_TRUE(shape<1>(result) == 2); + ASSERT_TRUE(decltype(stride<1>(result) == Int<32>{})::value); } { @@ -221,10 +221,10 @@ TEST(CuTe_core, Logical_divide) // Enforcement for dynamic cases auto result = logical_divide(layout, tile); - static_assert(decltype(shape<0>(result) == Int<48>{})::value); - static_assert(decltype(stride<0>(result) == Int<1>{})::value); - assert(shape<1>(result) == 1); - static_assert(decltype(stride<1>(result) == Int<48>{})::value); + ASSERT_TRUE(decltype(shape<0>(result) == Int<48>{})::value); + ASSERT_TRUE(decltype(stride<0>(result) == Int<1>{})::value); + ASSERT_TRUE(shape<1>(result) == 1); + ASSERT_TRUE(decltype(stride<1>(result) == Int<48>{})::value); } // DISALLOWED diff --git a/test/unit/cute/core/logical_product.cpp b/test/unit/cute/core/logical_product.cpp index bcdae4ea..d812743c 100644 --- a/test/unit/cute/core/logical_product.cpp +++ b/test/unit/cute/core/logical_product.cpp @@ -46,13 +46,9 @@ test_logical_product(LayoutA const& layoutA, CUTLASS_TRACE_HOST(shape(layoutA) << " x " << shape(layoutB) << " => " << shape(layoutR) ); CUTLASS_TRACE_HOST(stride(layoutA) << " " << stride(layoutB) << " => " << stride(layoutR)); - // Test that layout R is compatible with layout B ASSERT_EQ(rank(layoutR), 2); - //assert(compatible(layoutB, layout<0>(layoutR))); - //assert(consistent(layoutA, layout<1>(layoutR))); - - // True post-condition: - + ASSERT_TRUE(layoutA == layout<0>(layoutR)); + ASSERT_TRUE(compatible(layoutB, layout<1>(layoutR))); } TEST(CuTe_core, Logical_product) diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 394a66ae..9f4b8e0e 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -37,6 +37,8 @@ #include #include #include +#include +#include #include "../../common/cutlass_unit_test.h" @@ -49,16 +51,16 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" + #include "testbed_utils.h" #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/matrix_coord.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/fast_math.h" -#include "cutlass/platform/platform.h" -#include "cutlass/epilogue/fusion/operations.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" #include "cute/int_tuple.hpp" #include "cute/layout.hpp" @@ -69,6 +71,21 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorBeta { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + namespace detail{ // Helper classes that take default data type when @@ -95,7 +112,48 @@ struct ElementScalarType && + !std::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; // The number of splits to test. // @@ -141,209 +199,124 @@ class Iterations { int iterations_ = 20; }; -// The maxium swizzle size to use -// -// This class, like Splits above makes it harder to confuse -// the order of arguments of the various run(...) functions in this file. -class MaxSwizzleSize { -public: - MaxSwizzleSize() = default; +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { - template && - !std::is_same_v)) > - explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} - explicit operator int() const { return max_swizzle_size_; } -private: - int max_swizzle_size_ = 1; -}; + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; -template -auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (is_subbyte_v) { - return subbyte_iterator(ptr); + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else{ + scope_max = 5; + scope_min = -5; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { - return ptr; + EXPECT_TRUE(false) << "Not implemented"; + return false; } + + return true; } -template < - typename Gemm, - template class ActivationFunctor_ = cutlass::epilogue::thread::Identity -> -struct TestbedImpl { +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template +struct HostCollectiveMainloop { // Kernel data types using ElementA = typename Gemm::GemmKernel::ElementA; using StrideA = typename Gemm::GemmKernel::StrideA; using ElementB = typename Gemm::GemmKernel::ElementB; using StrideB = typename Gemm::GemmKernel::StrideB; - using ElementC = std::conditional_t, - typename Gemm::GemmKernel::ElementD,typename Gemm::GemmKernel::ElementC>; - using StrideC = typename Gemm::GemmKernel::StrideC; - using ElementD = typename Gemm::GemmKernel::ElementD; - using StrideD = typename Gemm::GemmKernel::StrideD; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - /// For custom EVTs - using ElementCompute = typename ElementComputeType::Type; - using ElementScalar = typename ElementScalarType::Type; - using ActivationFunctor = ActivationFunctor_; - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static constexpr uint32_t mma_promotion_interval = 4; - - // Looks at Cute Stride to check Row / Column Major - template - static constexpr bool is_row_or_col_major(){ - int stride_0 = int(cute::size<0>(Stride{})); - int stride_1 = int(cute::size<1>(Stride{})); - int depth = cute::depth(Stride{}); - return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); - } - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : C Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : D Layout is neither Row / Column Major)"); + using Arguments = typename Gemm::GemmKernel::MainloopArguments; - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - using LayoutTagC = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagD = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; - /// Initialization StrideA stride_a; StrideB stride_b; - StrideC stride_c; - StrideD stride_d; + typename LayoutTagA::Stride stride_factor_A; typename LayoutTagB::Stride stride_factor_B; - typename LayoutTagC::Stride stride_factor_C; - typename LayoutTagD::Stride stride_factor_D; + cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - uint32_t sm_count; - - // Used to force multi-wave tests for persistent kernel schedules - constexpr static int MaxSmCount = 16; - - cutlass::ComplexTransform TransformA = Gemm::kTransformA; - cutlass::ComplexTransform TransformB = Gemm::kTransformB; - // - // Methods - // + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; - TestbedImpl( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): - stride_factor_A(typename LayoutTagA::Stride()), - stride_factor_B(typename LayoutTagB::Stride()), - stride_factor_C(typename LayoutTagC::Stride()), - stride_factor_D(typename LayoutTagD::Stride()), - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); - TestbedImpl( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, + HostCollectiveMainloop( cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() ): stride_factor_A(stride_factor_A_), stride_factor_B(stride_factor_B_), - stride_factor_C(stride_factor_C_), - stride_factor_D(stride_factor_D_), - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } - else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } - else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } - else { - scope_max = 8; - scope_min = -8; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - - else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - } - - else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - - else if (dist_kind == cutlass::Distribution::AllOnes) { - cutlass::reference::host::TensorFill(view, Element(1)); - } - - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } + init_A(init_A_), init_B(init_B_), seed(seed_) { } - return true; - } - - /// Initializes data structures + template void initialize(ProblemShapeType problem_size) { // // Allocate the GEMM workspace @@ -356,12 +329,9 @@ struct TestbedImpl { stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto a_coord = cutlass::make_Coord(M * L, K); - auto c_coord = cutlass::make_Coord(M * L, N); // Cutlass has Row/Col major refers to MxK times KxN matrix product, // so the HostTensorB should be treated as KxN in "coord"'s view auto b_coord = cutlass::make_Coord(K, N * L); @@ -369,426 +339,319 @@ struct TestbedImpl { tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); - tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); - reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); - + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020)); // It is possible to randomly initialize to all zeros, so override this with non-zeros // in the upper left corner of each operand. tensor_A.host_view().at({0, 0}) = ElementA(1); tensor_B.host_view().at({0, 0}) = ElementB(1); - tensor_C.host_view().at({0, 0}) = ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); tensor_A.sync_device(); tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); } - /// Compares computed reference with device reference and outputs to a file if incorrect + Arguments to_args() { + return { + tensor_A.device_data(), stride_a, + tensor_B.device_data(), stride_b + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B, TransformA, TransformB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta) - { + cute::Shape problem_shape_MNKL) { auto [M, N, K, L] = problem_shape_MNKL; - tensor_D.sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + return true; + } +}; - if (tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - } - if (reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - } +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; - EXPECT_TRUE(passed); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + using FusionOp = typename Gemm::EpilogueOutputOp; - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\n\nComputed =\n" << tensor_D.host_view(); - } + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - return passed; - } + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); - /// Verifies the result is a GEMM - bool verify( - ProblemShapeType problem_size, - ElementScalar alpha, - ElementScalar beta) - { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; - auto A = cute::make_tensor(detail::make_iterator(tensor_A.host_data()), - cute::make_layout(cute::make_shape(M, K, L), stride_a)); - auto B = cute::make_tensor(detail::make_iterator(tensor_B.host_data()), - cute::make_layout(cute::make_shape(N, K, L), stride_b)); - auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - auto Bias = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Aux = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - auto Valpha = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Vbeta = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, cute::_1{}))); + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; - cutlass::reference::host::GettMainloopParams mainloop_params{A, B, TransformA, TransformB}; + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D), - decltype(Bias), - decltype(Aux), - decltype(Valpha), - decltype(Vbeta), - ActivationFunctor - > - epilogue_params{ - alpha, beta, - C, D, Bias, Aux - , Valpha, Vbeta - }; + /// Initialization + StrideC stride_c; + StrideD stride_d; - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - return compare_reference(problem_shape_MNKL, alpha, beta); - } + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; - /// Determine if the CUDA device is sufficient to run the kernel - bool sufficient() { - // - // Determine SMEM requirements and waive if not satisfied - // + cutlass::HostTensor tensor_C; + // Inputs + ElementScalar alpha; + ElementScalar beta; - int smem_size = Gemm::GemmKernel::SharedStorageSize; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector + VectorBeta disable_vector_beta = VectorBeta::DISABLED; - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - this->sm_count = properties.multiProcessorCount; + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } + void initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - return true; - } + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020)); + tensor_C.host_view().at({0, 0}) = ElementC(1); - bool profile( - ProblemShapeType problem_size, - int iterations, - Gemm& gemm_op, - typename Gemm::Arguments& arguments, - cutlass::device_memory::allocation& workspace) { - int M = cute::size<0>(problem_size); - int N = cute::size<1>(problem_size); - int K = cute::size<2>(problem_size); - int L = 1; - if constexpr(cute::rank(ProblemShapeType{}) == 4) { - L = cute::size<3>(problem_size); - } + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + tensor_C.sync_device(); + tensor_D.sync_device(); + alpha = alpha_; + beta = beta_; + } - cutlass::Status status; - // - // Run the GEMM - // - cudaError_t result; + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { - for (int iter = 0; iter < iterations; ++iter) { - status = gemm_op(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - return false; + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(0.1f); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); } } - - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); } - - return true; } - /// Executes one test - bool run( - ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - detail::Iterations iterations = detail::Iterations{}, - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic) - { - // Fail test if insufficient CUDA device - if (!sufficient()) { - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; - this->initialize(problem_size); + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - // - // Initialize the GEMM operator - // + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - if (not profiling) { - this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = this->sm_count; + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); } - else { - this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = this->sm_count; + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"<) { - arguments.scheduler.splits = static_cast(splits); - arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); - arguments.scheduler.raster_order = raster_order; - arguments.scheduler.decomposition_mode = decomposition_mode; - - } else { - arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); - arguments.scheduler.raster_order = raster_order; - } - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.can_implement(arguments); + }; - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } + return arguments; + } + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; // - // Run the GEMM + // Allocate the GEMM workspace // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); - if (profiling) { - return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); - } - else { - cudaError_t result; - status = gemm_op.initialize(arguments, workspace.get()); - status = gemm_op.run(); - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; - // - // Verify - // - bool passed = this->verify(problem_size, alpha, beta); - if (!passed) { - std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta - << "\n"; - } + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; - return passed; - } + return epilogue_params; } }; -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - template class ActivationFunctor -> -struct Testbed3x { - - using TestBedImpl = typename detail::TestbedImpl; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - - using ElementAccumulator = typename TestBedImpl::ElementAccumulator; - using ElementCompute = typename TestBedImpl::ElementCompute; - using ElementScalar = typename TestBedImpl::ElementScalar; - - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - using LayoutTagC = typename TestBedImpl::LayoutTagC; - using LayoutTagD = typename TestBedImpl::LayoutTagD; - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - // Detail Implementation - TestBedImpl impl_; - - // - // Methods - // - Testbed3x( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed) - : impl_(init_A_, init_B_, init_C_, seed_) {} - - Testbed3x( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed) - : impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - init_A_, - init_B_, - init_C_, - seed_) {} +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; - /// Executes one test - bool run( - typename TestBedImpl::ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic, - bool profiling = false, - detail::Iterations iterations = detail::Iterations{}) - { - return impl_.run( - problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode - ); - } -}; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); -///////////////////////////////////////////////////////////////////////////////////////////////// + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; -// Testbed for GEMMs with fused epilogues using the fusion::FusionOperation API -// Does not support testing of custom EVTs -template -struct Testbed3xFusionOperation { + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - using TestBedImpl = typename detail::TestbedImpl; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - using LayoutTagC = typename TestBedImpl::LayoutTagC; - using LayoutTagD = typename TestBedImpl::LayoutTagD; + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors using LayoutTagVector = cutlass::layout::PackedVectorLayout; - using ElementA = typename Kernel::ElementA; - using StrideA = typename Kernel::StrideA; - using ElementB = typename Kernel::ElementB; - using StrideB = typename Kernel::StrideB; - using ElementC = typename Kernel::ElementC; - using StrideC = typename Kernel::StrideC; - using ElementD = typename Kernel::ElementD; - using StrideD = typename Kernel::StrideD; - using ProblemShapeType = typename Kernel::ProblemShape; - using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; // // FusionOperation derived types/queries // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + using FusionOp = typename Gemm::EpilogueOutputOp; static_assert(cute::is_base_of_v); - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; using ElementBias = non_void_t; @@ -810,14 +673,15 @@ struct Testbed3xFusionOperation { static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && (cute::is_same_v || cute::is_same_v); - // Legacy support for deprecated bias-elementwise collective, will be removed next release - using EpiloguePolicy = typename Epilogue::DispatchPolicy; - static constexpr bool IsLegacy = - cute::is_same_v< - EpiloguePolicy, - cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< - EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> - >; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; // Inputs cutlass::HostTensor alpha; @@ -828,11 +692,15 @@ struct Testbed3xFusionOperation { cutlass::HostTensor scale_D; cutlass::HostTensor scale_Aux; cutlass::HostTensor bias; + cutlass::HostTensor tensor_C; + // Outputs cutlass::HostTensor abs_max_Aux; cutlass::HostTensor abs_max_D; cutlass::HostTensor tensor_Aux; cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; // References cutlass::HostTensor reference_dbias; @@ -840,81 +708,73 @@ struct Testbed3xFusionOperation { cutlass::HostTensor reference_abs_max_Aux; cutlass::HostTensor reference_abs_max_D; - // Detail Implementation - TestBedImpl impl_; - // Whether to use relative equality checks - bool check_relative_equality = false; + CheckEquality check_relative_equality = CheckEquality::EXACT; // Are scalars copied to device memory before kernel launch - bool use_device_scalars = false; + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector - bool disable_vector_beta = false; + VectorBeta disable_vector_beta = VectorBeta::DISABLED; + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; // Random distribution with which to initialize the bias vector cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; - // - // Methods - // - Testbed3xFusionOperation( - bool check_relative_equality_ = false, - bool use_device_scalars_ = false, - bool disable_vector_beta_ = false, + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : impl_(init_A_, init_B_, init_C_, seed_), - check_relative_equality(check_relative_equality_), - use_device_scalars(use_device_scalars_), - init_scale(init_scale_), init_bias(init_bias_) { } - - Testbed3xFusionOperation( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - init_A_, - init_B_, - init_C_, - seed_) { } + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } - /// Initializes data structures void initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - auto scalar_coord = cutlass::make_Coord(1); - auto col_vector_coord = cutlass::make_Coord(M); - // Allocate the GEMM workspace for A/B/C/D tensor - impl_.initialize(problem_size); + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020)); + tensor_C.host_view().at({0, 0}) = ElementC(1); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + tensor_C.sync_device(); + tensor_D.sync_device(); + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); if constexpr (IsPerRowScaleEnabled) { alpha.resize(col_vector_coord); - EXPECT_TRUE(impl_.initialize_tensor(alpha.host_view(), init_scale, impl_.seed + 2023)); - if (disable_vector_beta) { + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (disable_vector_beta == VectorBeta::DISABLED) { beta.resize(scalar_coord, false); cutlass::reference::host::TensorFill(beta.host_view(), beta_); } else { beta.resize(col_vector_coord); - EXPECT_TRUE(impl_.initialize_tensor(beta.host_view(), init_scale, impl_.seed + 2024)); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); } } else { - alpha.resize(scalar_coord, use_device_scalars); - beta.resize(scalar_coord, use_device_scalars); + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); cutlass::reference::host::TensorFill(beta.host_view(), beta_); } @@ -922,14 +782,14 @@ struct Testbed3xFusionOperation { beta.sync_device(); if constexpr (IsScaleFactorEnabled) { - scale_A.resize(scalar_coord, use_device_scalars); - scale_B.resize(scalar_coord, use_device_scalars); - scale_C.resize(scalar_coord, use_device_scalars); - scale_D.resize(scalar_coord, use_device_scalars); - EXPECT_TRUE(impl_.initialize_tensor(scale_A.host_view(), init_scale, impl_.seed + 2023)); - EXPECT_TRUE(impl_.initialize_tensor(scale_B.host_view(), init_scale, impl_.seed + 2024)); - EXPECT_TRUE(impl_.initialize_tensor(scale_C.host_view(), init_scale, impl_.seed + 2025)); - EXPECT_TRUE(impl_.initialize_tensor(scale_D.host_view(), init_scale, impl_.seed + 2026)); + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); scale_A.sync_device(); scale_B.sync_device(); scale_C.sync_device(); @@ -938,7 +798,7 @@ struct Testbed3xFusionOperation { if constexpr (IsBiasEnabled) { bias.resize(col_vector_coord); - EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), init_bias, impl_.seed + 2023)); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); bias.sync_device(); } @@ -964,7 +824,7 @@ struct Testbed3xFusionOperation { auto aux_coord = cutlass::make_Coord(M * L, N); auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); tensor_Aux.resize(aux_coord, aux_layout); - EXPECT_TRUE(impl_.initialize_tensor(tensor_Aux.host_view(), impl_.init_C, impl_.seed + 2023)); + EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); tensor_Aux.sync_device(); stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); } @@ -978,8 +838,8 @@ struct Testbed3xFusionOperation { stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); if constexpr (IsScaleFactorEnabled) { - scale_Aux.resize(scalar_coord, use_device_scalars); - EXPECT_TRUE(impl_.initialize_tensor(scale_Aux.host_view(), init_scale, impl_.seed + 2027)); + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); scale_Aux.sync_device(); } @@ -993,6 +853,7 @@ struct Testbed3xFusionOperation { cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); } } + } template < @@ -1010,151 +871,235 @@ struct Testbed3xFusionOperation { Element epsilon(0.1f); Element nonzero_floor(std::numeric_limits::min()); - if (check_relative_equality) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference(cute::Shape problem_shape_MNKL) { - + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"< 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0); + file << "\nComputed abs_max_Aux :"; + file << " " << float(abs_max_Aux.at(coord_0)); + file << "\n\n"; } - bool passed = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); - if constexpr (IsAbsMaxEnabledD) { - abs_max_D.sync_host(); - passed &= equality_check(reference_abs_max_D.host_view(), abs_max_D.host_view()); + if constexpr (IsBiasEnabled) { + file << "\n\nBias = \n" << bias.host_view(); + } + + if constexpr (IsAuxInEnabled) { + file << "\n\nAux Input = \n" << tensor_Aux.host_view(); } if constexpr (IsDeBiasEnabled) { - bias.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(bias.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_dbias.host_view()), 0); - passed &= equality_check(reference_dbias.host_view(), bias.host_view()); + file << "\n\nReference dBias = \n" << reference_dbias.host_view(); + file << "\n\nComputed dBias = \n" << bias.host_view(); } if constexpr (IsAuxOutEnabled) { - tensor_Aux.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); - passed &= equality_check(reference_Aux.host_view(), tensor_Aux.host_view()); - if constexpr (IsAbsMaxEnabledAux) { - abs_max_Aux.sync_host(); - passed &= equality_check(reference_abs_max_Aux.host_view(), abs_max_Aux.host_view()); - } + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\n\nComputed Aux =\n" << tensor_Aux.host_view(); } + file + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\n\nComputed =\n" << tensor_D.host_view(); - EXPECT_TRUE(passed); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + } - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L; - if constexpr (IsScaleFactorEnabled) { - file - << ", scale_a: " << scale_A.at(coord_0) - << ", scale_b: " << scale_B.at(coord_0) - << ", scale_c: " << scale_C.at(coord_0); - } - if constexpr (IsPerRowScaleEnabled) { - file << "\n\nvalpha = \n" << alpha.host_view(); - file << "\n\nvbeta = \n" << beta.host_view(); - } else { - file - << ", alpha: " << alpha.at(coord_0) << ", beta: " << beta.at(coord_0); - } - file << "\n\n"; + Arguments to_args(ProblemShapeType problem_size) { + auto coord_0 = cutlass::make_Coord(0); + Arguments arguments = + { + {}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }; - if constexpr (IsAbsMaxEnabledD) { - file << "scale_d: " << float(scale_D.at(coord_0)); - file << "\nReference abs_max_D :"; - file << " " << float(reference_abs_max_D.at(coord_0)); + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = tensor_Aux.device_data(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + fusion_args.beta_ptr = beta.device_data(); // if disable_vector_beta is true this is nullptr - file << "\nComputed abs_max_D :"; - file << " " << float(abs_max_D.at(coord_0)); - file << "\n\n"; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); } - if constexpr (IsAbsMaxEnabledAux) { - file << "scale_aux: " << float(scale_Aux.at(coord_0)); - file << "\nReference abs_max_Aux :"; - file << " " << float(reference_abs_max_Aux.at(coord_0)); + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } - file << "\nComputed abs_max_Aux :"; - file << " " << float(abs_max_Aux.at(coord_0)); - file << "\n\n"; + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); } - file - << "A =\n" << impl_.tensor_A.host_view() - << "\nB =\n" << impl_.tensor_B.host_view() - << "\nC =\n" << impl_.tensor_C.host_view(); + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } - if constexpr (IsBiasEnabled) { - file << "\n\nBias = \n" << bias.host_view(); + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); } - if constexpr (IsAuxInEnabled) { - file << "\n\nAux Input = \n" << tensor_Aux.host_view(); + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); } - if constexpr (IsDeBiasEnabled) { - file << "\n\nReference dBias = \n" << reference_dbias.host_view(); - file << "\n\nComputed dBias = \n" << bias.host_view(); + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; } if constexpr (IsAuxOutEnabled) { - file - << "\n\nReference Aux =\n" << reference_Aux.host_view() - << "\n\nComputed Aux =\n" << tensor_Aux.host_view(); + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } } - file - << "\n\nReference D =\n" << impl_.reference_D.host_view() - << "\n\nComputed D =\n" << impl_.tensor_D.host_view(); + } - return passed; + return arguments; } - /// Verifies the result against a reference implementation - bool verify(ProblemShapeType problem_size) - { + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto M = cute::get<0>(problem_shape_MNKL); auto N = cute::get<1>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); auto coord_0 = cutlass::make_Coord(0); - - auto A = cute::make_tensor(detail::make_iterator(impl_.tensor_A.host_data()), - cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); - auto B = cute::make_tensor(detail::make_iterator(impl_.tensor_B.host_data()), - cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); - auto C = cute::make_tensor(detail::make_iterator(impl_.tensor_C.host_data()), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); - auto D = cute::make_tensor(detail::make_iterator(impl_.reference_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), @@ -1163,9 +1108,6 @@ struct Testbed3xFusionOperation { cute::make_layout(cute::make_shape(M, cute::_1{}))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); - - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -1177,8 +1119,8 @@ struct Testbed3xFusionOperation { decltype(Aux), decltype(Valpha), decltype(Vbeta), - ActivationFunctor> - epilogue_params{}; + ActivationFunctor + > epilogue_params{}; epilogue_params.C = C; epilogue_params.D = D; @@ -1216,156 +1158,262 @@ struct Testbed3xFusionOperation { if constexpr (IsPerRowScaleEnabled) { epilogue_params.Valpha = Valpha; - if (not disable_vector_beta) { + if (disable_vector_beta == VectorBeta::ENABLED) { epilogue_params.Vbeta = Vbeta; } } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_size); + collective_epilogue.initialize(problem_size, alpha_, beta_); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) + { + auto [M, N, K, L] = problem_shape_MNKL; + + bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); + passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, beta); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file); + collective_epilogue.print_tensors(file); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); + auto epilogue_params = collective_epilogue.to_host_args(problem_size); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + bool passed = compare_reference(problem_shape_MNKL, alpha, beta); + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = Gemm::GemmKernel::SharedStorageSize; + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } - return compare_reference(problem_shape_MNKL); + return true; } /// Executes one test bool run( ProblemShapeType problem_size, - ElementScalar alpha_ = ElementScalar(1), - ElementScalar beta_ = ElementScalar(0), + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}, RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic, - bool profiling = false, - detail::Iterations iterations = detail::Iterations{}) + DecompositionMode decomposition_mode = DecompositionMode::Heuristic) { + // Fail test if insufficient CUDA device - if (!impl_.sufficient()) { + if (!sufficient()) { std::cout << "Test failed due to insufficient CUDA device." << std::endl; return false; } + + this->initialize(problem_size, alpha, beta); + // // Initialize the GEMM operator // typename Gemm::Arguments arguments; cutlass::KernelHardwareInfo hw_info; - cudaDeviceProp prop; - hw_info.device_id = 0; if (not profiling) { - impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = impl_.sm_count; + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; } else { - impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = impl_.sm_count; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; } - cudaGetDeviceProperties(&prop, hw_info.device_id); - - /// Initializes data structures - /// A/B/C/D Tensor - initialize(problem_size, alpha_, beta_); - - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - { - impl_.tensor_A.device_data(), impl_.stride_a, - impl_.tensor_B.device_data(), impl_.stride_b - }, - { // Epilogue arguments - {}, // thread - impl_.tensor_C.device_data(), - impl_.stride_c, - impl_.tensor_D.device_data(), - impl_.stride_d - }, // Epilogue arguments end - hw_info - }; - + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; if constexpr (std::is_same_v) { - arguments.scheduler.splits = static_cast(splits); - arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); - arguments.scheduler.raster_order = raster_order; - arguments.scheduler.decomposition_mode = decomposition_mode; - } else { - arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); - arguments.scheduler.raster_order = raster_order; - } - - auto coord_0 = cutlass::make_Coord(0); - if constexpr (IsLegacy) { - arguments.epilogue.thread = { - alpha.at(coord_0), - beta.at(coord_0), - alpha.device_data(), - beta.device_data() - }; - arguments.epilogue.ptr_Bias = bias.device_data(); - arguments.epilogue.ptr_T = tensor_Aux.device_data(); + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; } else { - auto &fusion_args = arguments.epilogue.thread; - - fusion_args.alpha = alpha.at(coord_0); - fusion_args.beta = beta.at(coord_0); - fusion_args.alpha_ptr = alpha.device_data(); - fusion_args.beta_ptr = beta.device_data(); // if disable_vector_beta is true this is nullptr - - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_a = scale_A.at(coord_0); - fusion_args.scale_b = scale_B.at(coord_0); - fusion_args.scale_c = scale_C.at(coord_0); - fusion_args.scale_d = scale_D.at(coord_0); - fusion_args.scale_a_ptr = scale_A.device_data(); - fusion_args.scale_b_ptr = scale_B.device_data(); - fusion_args.scale_c_ptr = scale_C.device_data(); - fusion_args.scale_d_ptr = scale_D.device_data(); - } - - if constexpr (IsBiasEnabled) { - fusion_args.bias_ptr = bias.device_data(); - } - - if constexpr (IsDeBiasEnabled) { - fusion_args.dbias_ptr = bias.device_data(); - } - - // example of how to set kernel activation arguments - // see ActivationFunctor::Arguments in activation.h for definition - // if Arguments doesn't exist then fusion_args.activation is empty - if constexpr (cute::is_same_v>) { - fusion_args.activation.scale = ElementCompute(1); - } - - // Treat Clamp as ReLU - if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = 0; - fusion_args.activation.upper_bound = std::numeric_limits::max(); - } - - if constexpr (IsAbsMaxEnabledD) { - fusion_args.amax_D_ptr = abs_max_D.device_data(); - } - - if constexpr (IsAuxInEnabled) { - fusion_args.aux_ptr = tensor_Aux.device_data(); - fusion_args.dAux = stride_Aux; - } - - if constexpr (IsAuxOutEnabled) { - fusion_args.aux_ptr = tensor_Aux.device_data(); - fusion_args.dAux = stride_Aux; - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_aux = scale_Aux.at(coord_0); - fusion_args.scale_aux_ptr = scale_Aux.device_data(); - } - if constexpr (IsAbsMaxEnabledAux) { - fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); - } - } + scheduler_args = { static_cast(max_swizzle), raster_order }; } + arguments = { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + collective_mma_inputs.to_args(), + collective_epilogue.to_args(problem_size), + hw_info, + scheduler_args + }; + Gemm gemm_op; @@ -1385,7 +1433,7 @@ struct Testbed3xFusionOperation { // if (profiling) { - return impl_.profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); + return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); } else { cudaError_t result; @@ -1402,27 +1450,134 @@ struct Testbed3xFusionOperation { // // Verify // - bool passed = this->verify(problem_size); + bool passed = this->verify(problem_size, alpha, beta); if (!passed) { - std::cout << "Error : Failed : with alpha: " << float(alpha_) << ", beta: " << float(beta_) + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta << "\n"; } + return passed; } } }; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + ); + } +}; ///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf3x(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed3x testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, + true, // profiling + detail::Iterations{iterations}); + + if (!passed) { + return false; + } + } + } + } + + return true; +} + template < typename Gemm, - typename Testbed = Testbed3x + template class ActivationFunctor = cutlass::epilogue::thread::Identity > -bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorBeta::DISABLED); + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; @@ -1440,11 +1595,11 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; std::vector decomposition_modes = {DecompositionMode::Heuristic}; - std::vector problem_splits = {1}; + std::vector problem_splits = {detail::Splits{1}}; static constexpr bool UsesStreamKScheduler = std::is_same_v; if constexpr (UsesStreamKScheduler) { - problem_splits.push_back(2); - problem_splits.push_back(3); + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{3}); decomposition_modes.push_back(DecompositionMode::DataParallel); decomposition_modes.push_back(DecompositionMode::SplitK); @@ -1457,7 +1612,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; - std::vector max_swizzle_sizes = {1, 4}; + std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; bool passed = true; @@ -1465,26 +1620,26 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { for (int n : problem_size_n) { for (int k : problem_size_k) { for (auto raster_order : raster_orders) { - for (int max_swizzle_size : max_swizzle_sizes) { + for (auto max_swizzle_size : max_swizzle_sizes) { for (DecompositionMode decomp_mode : decomposition_modes) { - std::vector problem_splits = {1}; - if (UsesStreamKScheduler && (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK)) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { auto max_splits = (k + TileShapeK - 1) / TileShapeK; if (max_splits > 2) { - problem_splits.push_back(2); + problem_splits.push_back(detail::Splits{2}); } if (max_splits > 3) { - problem_splits.push_back(3); + problem_splits.push_back(detail::Splits{3}); } - problem_splits.push_back(max_splits); + problem_splits.push_back(detail::Splits{max_splits}); // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this // case, split-K will fall back to a splitting factor of `max_splits`. - problem_splits.push_back(max_splits + 1); + problem_splits.push_back(detail::Splits{max_splits + 1}); } - for (int splits : problem_splits) { + for (auto splits : problem_splits) { ProblemShapeType problem_size; if constexpr (cute::rank(ProblemShapeType{}) == 4) { problem_size = ProblemShapeType{m, n, k, /* l */ 1}; @@ -1498,12 +1653,13 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { cutlass::from_real(alpha), cutlass::from_real(beta), raster_order, - detail::MaxSwizzleSize(max_swizzle_size), - detail::Splits(splits), + max_swizzle_size, + splits, decomp_mode ); if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; return false; } } // splits @@ -1531,77 +1687,11 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { return passed; } -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllBiasElementwise(double alpha = 1.0, double beta = 0.0, bool check_relative_equality=false) { - Testbed3xFusionOperation testbed(check_relative_equality); - - return TestAll(alpha, beta, testbed); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - template -bool TestGemmPerf3x(int iterations = 20) { - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalar = ElementAccumulator; - bool passed = true; - - std::vector problem_size_m = { 4608 }; - std::vector problem_size_n = { 4608 }; - std::vector problem_size_k = { 8192 }; - - Testbed3x testbed; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0), - true, - detail::Iterations(iterations) - ); - - if (!passed) { - return false; - } - } - } - } - - - // if we do support batched GEMM, just run it once - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - auto problem_size = ProblemShapeType{problem_size_m[0], problem_size_n[0], problem_size_k[0], /* l */ 4}; - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0), - true, - detail::Iterations(iterations) - ); - - if (!passed) { - return false; - } - } - - return passed; +bool TestAllBiasElementwise(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { + return TestAll(alpha, beta, check_relative_equality); } - } // namespace device } // namespace gemm } // namespace test diff --git a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp index 90034d07..3a5d818e 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -58,7 +58,7 @@ template < class HostEVTNodeBase { public: using Gemm = Gemm_; - using TestBedImpl = typename detail::TestbedImpl; + using TestBedImpl = typename detail::TestbedImpl; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Kernel::CollectiveEpilogue; using ElementCompute = typename TestBedImpl::ElementCompute; @@ -238,9 +238,9 @@ class HostRowBroadcast: public HostEVTNodeBase { _bias.resize(cutlass::Coord<1>(_N)); EXPECT_TRUE( - impl_.initialize_tensor( + detail::initialize_tensor( _bias.host_view(), cutlass::Distribution::Uniform, - impl_.seed + 2023 + impl_.collective_mma_inputs.seed + 2023 ) ); _bias.sync_device(); @@ -306,9 +306,9 @@ class HostColBroadcast: public HostEVTNodeBase { _bias.resize(cutlass::Coord<1>(_M)); EXPECT_TRUE( - impl_.initialize_tensor( + detail::initialize_tensor( _bias.host_view(), cutlass::Distribution::Uniform, - impl_.seed + 2023 + impl_.collective_mma_inputs.seed + 2023 ) ); _bias.sync_device(); @@ -393,10 +393,10 @@ class HostAuxLoad: public HostEVTNodeBase { ) ); EXPECT_TRUE( - impl_.initialize_tensor( + detail::initialize_tensor( _tensor_aux_load.host_view(), cutlass::Distribution::Uniform, - impl_.seed + 2023 + impl_.collective_mma_inputs.seed + 2023 ) ); _tensor_aux_load.sync_device(); @@ -1154,7 +1154,7 @@ class Testbed3xEVT { // The EVT Module to test using EVTModule = typename EVT::EVTModule; - using TestBedImpl = typename detail::TestbedImpl; + using TestBedImpl = typename detail::TestbedImpl; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; using ElementAccumulator = typename Kernel::ElementAccumulator; @@ -1178,7 +1178,9 @@ class Testbed3xEVT { cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(check_relative_equality_) { } + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } Testbed3xEVT( cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, @@ -1186,7 +1188,9 @@ class Testbed3xEVT { cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(false) { } + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } Testbed3xEVT( typename LayoutTagA::Stride stride_factor_A_, @@ -1198,15 +1202,10 @@ class Testbed3xEVT { cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - init_A_, - init_B_, - init_C_, - seed_), - check_relative_equality(false) { } + impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } /// Initializes data structures void initialize(ProblemShapeType problem_size) { @@ -1229,11 +1228,11 @@ class Testbed3xEVT { auto K = cute::get<2>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); - auto A = cute::make_tensor(impl_.tensor_A.host_data(), - cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); - auto B = cute::make_tensor(impl_.tensor_B.host_data(), - cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); - auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d); + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; @@ -1277,9 +1276,9 @@ class Testbed3xEVT { << ", Batch count = " << L << "\n\n"; file - << "A =\n" << impl_.tensor_A.host_view() - << "\nB =\n" << impl_.tensor_B.host_view() - << "\nC =\n" << impl_.tensor_C.host_view() << "\n\n"; + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; file << error_ss.str(); } @@ -1329,15 +1328,15 @@ class Testbed3xEVT { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, { - impl_.tensor_A.device_data(), impl_.stride_a, - impl_.tensor_B.device_data(), impl_.stride_b + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b }, { // Epilogue arguments {}, // thread static_cast(host_reference.get_tensor_C_ptr()), - impl_.stride_c, + impl_.collective_epilogue.stride_c, static_cast(host_reference.get_tensor_D_ptr()), - impl_.stride_d + impl_.collective_epilogue.stride_d }, // Epilogue arguments end hw_info, scheduler_args diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp index 0cabaa72..70b64468 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -101,7 +101,8 @@ struct Testbed3xTensorBroadcast { cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_(init_A_, init_B_, init_C_, seed_) { } + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } Testbed3xTensorBroadcast( typename LayoutTagA::Stride stride_factor_A_, @@ -117,9 +118,12 @@ struct Testbed3xTensorBroadcast { stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorBeta::ENABLED, init_A_, init_B_, init_C_, + cutlass::Distribution::Uniform, + cutlass::Distribution::Uniform, seed_) { } /// Initializes data structures @@ -135,7 +139,7 @@ struct Testbed3xTensorBroadcast { auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); bias.resize(cutlass::Coord<1>(bias_size)); - EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023)); + EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); bias.sync_device(); } @@ -147,8 +151,8 @@ struct Testbed3xTensorBroadcast { auto c_coord = cutlass::make_Coord(M * L, N); - tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_C)); - EXPECT_TRUE(impl_.initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2024)); + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); + EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); tensor_C1.sync_device(); } @@ -161,19 +165,19 @@ struct Testbed3xTensorBroadcast { { auto [M, N, K, L] = problem_shape_MNKL; - impl_.tensor_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_B.host_view()), 0); + impl_.collective_epilogue.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); - if (impl_.tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0); + if (impl_.collective_epilogue.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); } - if (impl_.reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0); + if (impl_.collective_epilogue.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); } - bool passed = cutlass::reference::host::TensorEquals(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); + bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); EXPECT_TRUE(passed); @@ -196,12 +200,12 @@ struct Testbed3xTensorBroadcast { } file - << "A =\n" << impl_.tensor_A.host_view() - << "\nB =\n" << impl_.tensor_B.host_view() - << "\nC0 =\n" << impl_.tensor_C.host_view() + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\nC1 =\n" << tensor_C1.host_view() - << "\n\nReference =\n" << impl_.reference_D.host_view() - << "\n\nComputed =\n" <(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); - auto A = cute::make_tensor(impl_.tensor_A.host_data(), - cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); - auto B = cute::make_tensor(impl_.tensor_B.host_data(), - cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); - auto D = cute::make_tensor(impl_.reference_D.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1))); - auto C0 = cute::make_tensor(impl_.tensor_C.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); auto C1 = cute::make_tensor(tensor_C1.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); // Create host workspace for output of testbed. This computes a portion of the epilogue: // ref_compute_out = Activation(alpha * (A @ B) + bias) cutlass::HostTensor ref_compute_out; auto c_coord = cutlass::make_Coord(M * L, N); - ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_C), false); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; // Use a dummy null tensor for operand C because the epilogue overrides C. auto dummy_C = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); ElementCompute dummy_beta(0); auto dummy_Aux = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, 1))); auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, 1))); - cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -361,17 +364,17 @@ struct Testbed3xTensorBroadcast { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - { impl_.tensor_A.device_data(), impl_.stride_a, - impl_.tensor_B.device_data(), impl_.stride_b, + { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, impl_.mma_promotion_interval }, { // Epilogue arguments { alpha, beta }, // ThreadOp arguments - impl_.stride_c, - impl_.tensor_D.device_data(), - impl_.stride_d, + impl_.collective_epilogue.stride_c, + impl_.collective_epilogue.tensor_D.device_data(), + impl_.collective_epilogue.stride_d, use_bias ? bias.device_data() : nullptr, - impl_.tensor_C.device_data(), + impl_.collective_epilogue.tensor_C.device_data(), tensor_C1.device_data() }, // Epilogue arguments end hw_info diff --git a/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu index 14654c78..0c542d6d 100644 --- a/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -112,6 +112,8 @@ TEST(SM80_Device_Gemm_tf32t_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { EXPECT_TRUE(test::gemm::device::TestAll()); } +///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Gemm_tf32t_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -132,4 +134,24 @@ TEST(SM80_Device_Gemm_tf32t_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { ///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_tf32t_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64_profiling) { + using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::tfloat32_t, cutlass::layout::RowMajor, + cutlass::tfloat32_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Config::CollectiveMainloop, + Config::CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestGemmPerf3x()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + //#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu index 408d9b31..a00cb5eb 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu @@ -97,9 +97,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - test::gemm::device::Testbed3x testbed; - bool passed = test::gemm::device::TestAll(1, 1, testbed); + bool passed = test::gemm::device::TestAll(1, 1); EXPECT_TRUE(passed); } @@ -156,6 +154,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 #pragma GCC diagnostic pop // Re-enable deprecation warnings } + TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -239,9 +238,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - bool check_relative_equality = true; - bool passed = test::gemm::device::TestAllBiasElementwise(1, 1, check_relative_equality); + using namespace test::gemm::device; + bool passed = TestAllBiasElementwise(1, 1, CheckEquality::RELATIVE); EXPECT_TRUE(passed); } @@ -600,8 +598,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - bool passed = test::gemm::device::TestAllBiasElementwise(1.0, 0.0, /*check_relative_equality=*/true); + using namespace test::gemm::device; + bool passed = TestAllBiasElementwise(1.0, 0.0, CheckEquality::RELATIVE); EXPECT_TRUE(passed); } diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu index 83f03e6d..96765250 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu @@ -97,8 +97,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - test::gemm::device::Testbed3x testbed; - bool passed = test::gemm::device::TestAll(1, 1, testbed); + bool passed = test::gemm::device::TestAll(1, 1); EXPECT_TRUE(passed); } @@ -186,8 +185,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool check_relative_equality = true; - bool passed = test::gemm::device::TestAllBiasElementwise(1, 1, check_relative_equality); + using namespace test::gemm::device; + bool passed = TestAllBiasElementwise(1, 1, CheckEquality::RELATIVE); EXPECT_TRUE(passed); } diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu index e4b92ff9..e455d8a9 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu index 99370aa0..575f7a7d 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu index a1f352d6..864ee38c 100644 --- a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ diff --git a/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu index 7090a0a6..e2df959f 100644 --- a/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu @@ -50,6 +50,7 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// +#if (!((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ == 8))) TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_f32, 64x64x16_32x32x16) { @@ -145,6 +146,7 @@ TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_f32, 64x64x16_32x32x16) { EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); } +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu index 0c6efb1e..0c5c25d0 100644 --- a/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu +++ b/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu @@ -50,6 +50,7 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// +#if (!((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ == 8))) TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_fast_f32, 64x64x16_32x32x16) { @@ -145,6 +146,7 @@ TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_fast_f32, 64x64x16_32x32x16) { EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); } +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu b/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu index 3f7b03ac..5f13ef4b 100644 --- a/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu +++ b/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu @@ -50,6 +50,7 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// +#if (!((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ == 8))) TEST(SM80_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { @@ -145,6 +146,7 @@ TEST(SM80_Device_Syr2k_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); } +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/nvrtc/stdlib/assert.h b/test/unit/nvrtc/stdlib/assert.h index e69de29b..efc3225a 100644 --- a/test/unit/nvrtc/stdlib/assert.h +++ b/test/unit/nvrtc/stdlib/assert.h @@ -0,0 +1,30 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 6a1aa6b5..641eac83 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -278,6 +278,7 @@ execute_process( --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" + --kernel-filter-file "${KERNEL_FILTER_FILE}" --selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}" --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" --log-level DEBUG diff --git a/tools/profiler/include/cutlass/profiler/device_allocation.h b/tools/profiler/include/cutlass/profiler/device_allocation.h index b5b3ee4a..95e552b3 100644 --- a/tools/profiler/include/cutlass/profiler/device_allocation.h +++ b/tools/profiler/include/cutlass/profiler/device_allocation.h @@ -207,7 +207,10 @@ class DeviceAllocation { void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); /// Uniformly fills a tensor with a value when provided o.w. zero - void fill(double value); + void fill_device(double value); + + /// Uniformly fills a host allocation with a value when provided o.w. zero + void fill_host(double value); /// Copies from an equivalent-sized tensor in device memory void copy_from_device(void const *ptr); diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 0419692f..873dc01d 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -2160,7 +2160,7 @@ static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) { } /// Fills a tensor uniformly with a value (most frequently used to clear the tensor) -void DeviceAllocation::fill(double val = 0.0) { +void DeviceAllocation::fill_device(double val = 0.0) { switch (this->type()) { case library::NumericTypeID::kFE4M3: @@ -2259,6 +2259,180 @@ void DeviceAllocation::fill(double val = 0.0) { } } +/// Fills a tensor uniformly with a value (most frequently used to clear the tensor) +void DeviceAllocation::fill_host(double val = 0.0) { + + std::vector host_data(bytes()); + + switch (this->type()) { + case library::NumericTypeID::kFE4M3: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kFE5M2: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kF16: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kBF16: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kTF32: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kF32: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kF64: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS2: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS4: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS8: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS16: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS32: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kS64: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kB1: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU2: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU4: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU8: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU16: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU32: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kU64: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + default: + throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type())); + } + + copy_from_host(host_data.data()); +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 0f7b19f9..e91577f2 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -77,6 +77,7 @@ struct GettMainloopParams { ComplexTransform transform_A = ComplexTransform::kNone; ComplexTransform transform_B = ComplexTransform::kNone; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -126,6 +127,7 @@ struct GettEpilogueParams { TensorAux Aux{}; VectorAlpha Valpha{}; VectorBeta Vbeta{}; + ElementCompute st = ElementCompute(1); ElementAccumulator* abs_max_D = nullptr; ElementAccumulator* abs_max_Aux = nullptr; @@ -204,6 +206,7 @@ void gett_mainloop( if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { a_frag[m_b] = conj(a_frag[m_b]); } @@ -218,6 +221,7 @@ void gett_mainloop( if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { b_frag[n_b] = conj(b_frag[n_b]); } @@ -325,6 +329,8 @@ void gett_epilogue( converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); converted_beta = mul(converted_beta, converted_scale_c); + ElementCompute inter_accum[kBlockM][kBlockN]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { ElementCompute local_dBias = ElementCompute(0); @@ -391,7 +397,7 @@ void gett_epilogue( output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); } - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); + inter_accum[m_b][n_b] = ElementCompute(output); } } // n_b @@ -403,6 +409,13 @@ void gett_epilogue( } } } // m_b + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } #if defined(_OPENMP) #pragma omp critical(Abs_Max_Data_Update) #endif diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 3d776e28..d6698f70 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -947,6 +947,20 @@ void TensorFillPadDiagonalRandomUniform( /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + /// Fills a tensor with random values with a uniform random distribution. template < typename Element ///< Element type