From 6f2dd7c461abb1a3a7bf87f187d2251230624629 Mon Sep 17 00:00:00 2001 From: Babak Poursartip Date: Fri, 17 Nov 2023 10:27:42 -0600 Subject: [PATCH 1/2] ROCm 6.1 merge staging into master (#1818) --- .jenkins/common.groovy | 9 +- .jenkins/integration.groovy | 2 +- CHANGELOG.md | 40 +- HostLibraryTests/CMakeLists.txt | 2 +- HostLibraryTests/CachingLibrary_test.cpp | 4 +- .../ContractionSelectionLibrary_test.cpp | 8 +- .../ProjectedPerformance_test.cpp | 30 +- .../llvm/LibraryPerformance_test.cpp | 8 +- Tensile/Common.py | 64 +- Tensile/Components/Signature.py | 19 + Tensile/Contractions.py | 10 +- Tensile/Hardware.py | 53 +- Tensile/KernelWriter.py | 45 +- Tensile/KernelWriterAssembly.py | 2540 ++++++++++++++--- Tensile/KernelWriterConversion.py | 177 +- Tensile/KernelWriterSource.py | 6 + Tensile/KernelWriterStreamKInit.py | 121 + Tensile/LibraryIO.py | 16 +- Tensile/SolutionLibrary.py | 10 +- Tensile/SolutionStructs.py | 418 ++- Tensile/SolutionWriter.py | 4 +- Tensile/Source/TensileTypes.h | 9 +- Tensile/Source/client/main.cpp | 6 +- .../Source/client/source/HardwareMonitor.cpp | 10 +- .../Source/client/source/SolutionIterator.cpp | 1 + Tensile/Source/lib/CMakeLists.txt | 1 + Tensile/Source/lib/include/Tensile/AMDGPU.hpp | 10 +- .../lib/include/Tensile/AMDGPUPredicates.hpp | 28 +- .../include/Tensile/ContractionProblem.hpp | 11 + .../Tensile/ContractionProblemPredicates.hpp | 4 +- .../include/Tensile/ContractionSolution.hpp | 76 +- .../Tensile/ContractionSolution_fwd.hpp | 1 + Tensile/Source/lib/include/Tensile/Debug.hpp | 2 + .../include/Tensile/PlaceholderLibrary.hpp | 40 +- .../Serialization/ContractionSolution.hpp | 5 +- .../Serialization/PlaceholderLibrary.hpp | 13 +- .../Tensile/Serialization/Predicates.hpp | 9 +- Tensile/Source/lib/source/AMDGPU.cpp | 9 +- .../Source/lib/source/ContractionProblem.cpp | 45 +- .../Source/lib/source/ContractionSolution.cpp | 392 ++- Tensile/Source/lib/source/Debug.cpp | 5 + Tensile/Source/lib/source/hip/HipHardware.cpp | 10 +- .../lib/source/hip/HipSolutionAdapter.cpp | 16 +- Tensile/Source/lib/source/ocl/OclHardware.cpp | 8 +- Tensile/TensileCreateLibrary.py | 5 +- Tensile/Tests/emulation/mfma/1LDSB.yaml | 9 +- .../custom_kernel/ck_dgemm_90a_nn.yaml | 4 + .../ck_dgemm_90a_nn_large_offset.yaml | 4 + .../extended/direct_to_vgpr/dtv_hgemm.yaml | 6 +- .../extended/direct_to_vgpr/dtv_igemm.yaml | 3 +- .../local_split_u/f8gemm_lsu_mfma.yaml | 2 + .../local_split_u/igemm_lsu_mfma.yaml | 2 + .../local_split_u/sgemm_lsu_mfma.yaml | 4 +- .../extended/stream_k/sk_2tile_hgemm_hhs.yaml | 88 + .../extended/stream_k/sk_2tile_sgemm.yaml | 119 + .../Tests/extended/stream_k/sk_hgemm_hhs.yaml | 88 + Tensile/Tests/extended/stream_k/sk_sgemm.yaml | 119 + .../direct_to_vgpr/dtv_sgemm_lite.yaml | 1 - Tensile/Tests/pre_checkin/mfma/1LDSB.yaml | 9 +- .../pre_checkin/mfma/dgemm_gb_global_ldd.yaml | 3 +- .../pre_checkin/mfma/wider_local_read.yaml | 18 +- .../Tests/pre_checkin/wmma/hgemm_wmma.yaml | 3 +- .../wmma/hpa_bfloat16_gemm_wmma.yaml | 3 +- .../pre_checkin/wmma/hpa_hgemm_wmma.yaml | 3 +- .../pre_checkin/wmma/hpa_igemm_wmma.yaml | 3 +- Tensile/Tests/unit/test_HardwarePredicates.py | 48 +- Tensile/__init__.py | 2 +- Tensile/cmake/TensileConfigVersion.cmake | 2 +- bump-version.sh | 4 +- pytest.ini | 1 + .../automation/rocblas-benchInputCreator.py | 53 +- tuning_docs/tensile_tuning.tex | 2 +- 72 files changed, 4058 insertions(+), 847 deletions(-) create mode 100644 Tensile/KernelWriterStreamKInit.py create mode 100644 Tensile/Tests/extended/stream_k/sk_2tile_hgemm_hhs.yaml 
create mode 100644 Tensile/Tests/extended/stream_k/sk_2tile_sgemm.yaml create mode 100644 Tensile/Tests/extended/stream_k/sk_hgemm_hhs.yaml create mode 100644 Tensile/Tests/extended/stream_k/sk_sgemm.yaml diff --git a/.jenkins/common.groovy b/.jenkins/common.groovy index e4ee79bb4..96561b377 100644 --- a/.jenkins/common.groovy +++ b/.jenkins/common.groovy @@ -34,7 +34,7 @@ def runCompileCommand(platform, project, jobName, boolean debug=false) // Do release build of HostLibraryTests on CI until it is upgraded to rocm 5.3 to // avoid bug causing long build times of certain files. String buildType = 'Release' // debug ? 'Debug' : 'RelWithDebInfo' - String parallelJobs = "export HIPCC_COMPILE_FLAGS_APPEND=-parallel-jobs=2" + String parallelJobs = "export HIPCC_COMPILE_FLAGS_APPEND='-O3 -Wno-format-nonliteral -parallel-jobs=4'" // comment @@ -63,7 +63,12 @@ def runCompileCommand(platform, project, jobName, boolean debug=false) export PATH=/opt/rocm/bin:\$PATH cmake -DCMAKE_BUILD_TYPE=${buildType} -DCMAKE_CXX_COMPILER=${compiler} -DTensile_ROOT=\$(pwd)/../Tensile ../HostLibraryTests - make -j\$(nproc) + NPROC_BUILD=16 + if [ `nproc` -lt 16 ] + then + NPROC_BUILD=`nproc` + fi + make -j\$NPROC_BUILD popd """ diff --git a/.jenkins/integration.groovy b/.jenkins/integration.groovy index c469ef300..b2c678559 100644 --- a/.jenkins/integration.groovy +++ b/.jenkins/integration.groovy @@ -45,7 +45,7 @@ def runCI = boolean formatCheck = false - prj.timeout.test = 90 + prj.timeout.test = 150 prj.defaults.ccache = false def commonGroovy diff --git a/CHANGELOG.md b/CHANGELOG.md index 00266c990..e34e41583 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,44 @@ # Change Log for Tensile -## (Unreleased) Tensile 4.39.0 +## (Unreleased) Tensile 4.40.0 +### Additions +- new DisableKernelPieces values to invalidate local read, local write, and global read +- stream-K kernel generation, including two-tile stream-k algorithm by setting StreamK=3 +- feature to allow testing stream-k grid multipliers +- debug output to check occupancy for Stream-K +- reject condition for FractionalLoad + DepthU!=power of 2 +- new TENSILE_DB debugging value to dump the common kernel parameters +- predicate for APU libs +- new parameter (ClusterLocalRead) to turn on/off wider local read opt for TileMajorLDS +- new parameter (ExtraLatencyForLR) to add extra interval between local read and wait +- new logic to check LDS size with auto LdsPad(=1) and change LdsPad to 0 if LDS overflows +- initialization type and general batched options to the rocblas-bench input creator script + +### Optimizations +- enabled MFMA + LocalSplitU=4 for MT16x16 +- enabled (DirectToVgpr + MI4x4) and supported skinny MacroTile +- optimized postGSU kernel: separate postGSU kernels for different GSU values, loop unroll for GSU loop, wider global load depending on array size, and parallel reduction depending on array size +- auto LdsPad calculation for TileMajorLds + MI16x16 +- auto LdsPad calculation for UnrollMajorLds + MI16x16 + VectorWidth + +### Changes +- cleared hipErrorNotFound error since it is an expected part of the search +- modified hipcc search path for Linux +- changed PCI ID from 32bit to 64bit for ROCm SMI HW monitor +- changed LdsBlockSizePerPad to LdsBlockSizePerPadA, B to specify LBSPP separately +- changed the default value of LdsPadA, B, LdsBlockSizePerPadA, B from 0 to -1 +- updated test cases according to parameter changes for LdsPad, LBSPP and ClusterLocalRead +- Replaced std::regex with fnmatch()/PathMatchSpec as a workaround to 
std::regex stack overflow known bug + +### Fixes +- hipcc compile append flag parallel-jobs=4 +- race condition in Stream-K that appeared with large grids and small sizes +- mismatch issue with LdsPad + LdsBlockSizePerPad!=0 and TailLoop +- mismatch issue with LdsPad + LdsBlockSizePerPad!=0 and SplitLds +- incorrect reject condition check for DirectToLds + LdsBlockSizePerPad=-1 case +- small fix for LdsPad optimization (LdsElement calculation) + +## Tensile 4.39.0 for ROCm 6.0 ### Added - Added aquavanjaram support: gfx940/gfx941/gfx942, fp8/bf8 datatype, xf32 datatype, and stochastic rounding for various datatypes - Added/updated tuning scripts diff --git a/HostLibraryTests/CMakeLists.txt b/HostLibraryTests/CMakeLists.txt index 97ef3cc1c..6a67cf7c1 100644 --- a/HostLibraryTests/CMakeLists.txt +++ b/HostLibraryTests/CMakeLists.txt @@ -59,7 +59,7 @@ if(TENSILE_STATIC_ONLY) endif() if(NOT Tensile_FOUND) - find_package(Tensile 4.39.0 EXACT REQUIRED ${TENSILE_COMPONENTS} PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../Tensile") + find_package(Tensile 4.40.0 EXACT REQUIRED ${TENSILE_COMPONENTS} PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../Tensile") endif() if(NOT TENSILE_DISABLE_CTEST) diff --git a/HostLibraryTests/CachingLibrary_test.cpp b/HostLibraryTests/CachingLibrary_test.cpp index 48bb7826e..0d312718c 100644 --- a/HostLibraryTests/CachingLibrary_test.cpp +++ b/HostLibraryTests/CachingLibrary_test.cpp @@ -255,8 +255,8 @@ TEST(Hashing, AMDGPU) for(int c1 : counts) for(int c2 : counts) { - AMDGPU g1(p1, c1, "g1"); - AMDGPU g2(p2, c2, "g2"); + AMDGPU g1(p1, c1, 0, "g1"); + AMDGPU g2(p2, c2, 0, "g2"); if(p1 != p2 || c1 != c2) { diff --git a/HostLibraryTests/ContractionSelectionLibrary_test.cpp b/HostLibraryTests/ContractionSelectionLibrary_test.cpp index 826c05014..161952622 100644 --- a/HostLibraryTests/ContractionSelectionLibrary_test.cpp +++ b/HostLibraryTests/ContractionSelectionLibrary_test.cpp @@ -39,7 +39,7 @@ using namespace Tensile; TEST(ContractionSelectionLibraryTest, Single) { std::shared_ptr hardware = std::make_shared( - AMDGPU::Processor::gfx900, 64, "AMD Radeon Vega Frontier Edition"); + AMDGPU::Processor::gfx900, 64, 0, "AMD Radeon Vega Frontier Edition"); SingleContractionLibrary lib; @@ -53,11 +53,11 @@ TEST(ContractionSelectionLibraryTest, Single) TEST(ContractionSelectionLibraryTest, GPUSelection) { std::shared_ptr v10 = std::make_shared( - AMDGPU::Processor::gfx900, 64, "AMD Radeon Vega Frontier Edition"); + AMDGPU::Processor::gfx900, 64, 0, "AMD Radeon Vega Frontier Edition"); std::shared_ptr v20 - = std::make_shared(AMDGPU::Processor::gfx906, 60, "AMD Radeon Vega 7"); + = std::make_shared(AMDGPU::Processor::gfx906, 60, 0, "AMD Radeon Vega 7"); std::shared_ptr v20_64CU - = std::make_shared(AMDGPU::Processor::gfx906, 64, "AMD Radeon Vega 7"); + = std::make_shared(AMDGPU::Processor::gfx906, 64, 0, "AMD Radeon Vega 7"); // Create solutions auto v20Solution = std::make_shared(); diff --git a/HostLibraryTests/ProjectedPerformance_test.cpp b/HostLibraryTests/ProjectedPerformance_test.cpp index 71c305fad..43c7195f4 100644 --- a/HostLibraryTests/ProjectedPerformance_test.cpp +++ b/HostLibraryTests/ProjectedPerformance_test.cpp @@ -52,12 +52,12 @@ std::map makeIdeals() return ideals; } -ContractionSolution::SizeMapping makeSizeMapping(Tensile::dim3 workGroupSize, - Tensile::dim3 threadTile, - Tensile::dim3 macroTile, - size_t globalSplitU) +SizeMapping makeSizeMapping(Tensile::dim3 workGroupSize, + Tensile::dim3 threadTile, + Tensile::dim3 macroTile, + size_t globalSplitU) { - 
ContractionSolution::SizeMapping sizeMapping; + SizeMapping sizeMapping; sizeMapping.workGroupSize = workGroupSize; sizeMapping.threadTile = threadTile; @@ -93,15 +93,14 @@ TEST(ContractionPerformance, Problem1) Tensile::dim3 macroTile = Tensile::dim3(64, 64, 16); size_t globalSplitU = 1; - ContractionSolution::SizeMapping sizeMapping - = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); + SizeMapping sizeMapping = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); solution->sizeMapping = sizeMapping; auto problem = ContractionProblem::GEMM(false, false, 1536, 1536, 64, 1536, 64, 1536, 1.5, false, 1.0); - AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, "gfx906"); + AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, 0, "gfx906"); double perf = solution->projectedPerformance(problem, hardware).speedGFlops; ASSERT_DOUBLE_EQ(perf, 3000.0); @@ -122,15 +121,14 @@ TEST(ContractionPerformance, Problem2) Tensile::dim3 macroTile = Tensile::dim3(64, 64, 16); size_t globalSplitU = 1; - ContractionSolution::SizeMapping sizeMapping - = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); + SizeMapping sizeMapping = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); solution->sizeMapping = sizeMapping; auto problem = ContractionProblem::GEMM(false, false, 384, 192, 60, 384, 60, 384, 1.5, false, 1.0); - AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, "gfx906"); + AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, 0, "gfx906"); double perf = solution->projectedPerformance(problem, hardware).speedGFlops; ASSERT_DOUBLE_EQ(perf, 843.75); @@ -151,15 +149,14 @@ TEST(ContractionPerformance, Problem3) Tensile::dim3 macroTile = Tensile::dim3(128, 128, 16); size_t globalSplitU = 1; - ContractionSolution::SizeMapping sizeMapping - = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); + SizeMapping sizeMapping = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); solution->sizeMapping = sizeMapping; auto problem = ContractionProblem::GEMM(false, false, 384, 192, 60, 384, 60, 384, 1.5, false, 1.0); - AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, "gfx906"); + AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, 0, "gfx906"); auto model = solution->projectedPerformance(problem, hardware); // std::cout << model << "\n"; @@ -184,15 +181,14 @@ TEST(ContractionPerformance, Problem4) Tensile::dim3 macroTile = Tensile::dim3(128, 64, 16); size_t globalSplitU = 4; - ContractionSolution::SizeMapping sizeMapping - = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); + SizeMapping sizeMapping = makeSizeMapping(workgroupSize, threadTile, macroTile, globalSplitU); solution->sizeMapping = sizeMapping; auto problem = ContractionProblem::GEMM(false, false, 1536, 1575, 64, 1536, 64, 1536, 1.5, false, 3.0); - AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, "gfx906"); + AMDGPU hardware(Tensile::AMDGPU::Processor::gfx906, 64, 0, "gfx906"); auto model = solution->projectedPerformance(problem, hardware); // std::cout << model << "\n"; diff --git a/HostLibraryTests/llvm/LibraryPerformance_test.cpp b/HostLibraryTests/llvm/LibraryPerformance_test.cpp index 5b2a6751e..b72721be2 100644 --- a/HostLibraryTests/llvm/LibraryPerformance_test.cpp +++ b/HostLibraryTests/llvm/LibraryPerformance_test.cpp @@ -285,8 +285,8 @@ std::vector GetLibraries(std::string const& e { std::vector rv; - std::vector gpus{AMDGPU(AMDGPU::Processor::gfx900, 64, "Vega 10"), - 
AMDGPU(AMDGPU::Processor::gfx906, 64, "Vega 20")}; + std::vector gpus{AMDGPU(AMDGPU::Processor::gfx900, 64, 0, "Vega 10"), + AMDGPU(AMDGPU::Processor::gfx906, 64, 0, "Vega 20")}; for(auto const& gpu : gpus) { @@ -298,9 +298,9 @@ std::vector GetLibraries(std::string const& e } rv.push_back(std::make_tuple( - AMDGPU(AMDGPU::Processor::gfx908, 64, "Arcturus"), "rocBLAS_Full." + ext, false, true)); + AMDGPU(AMDGPU::Processor::gfx908, 64, 0, "Arcturus"), "rocBLAS_Full." + ext, false, true)); rv.push_back(std::make_tuple( - AMDGPU(AMDGPU::Processor::gfx1010, 40, "Navi"), "KernelsLiteNavi." + ext, true, false)); + AMDGPU(AMDGPU::Processor::gfx1010, 40, 0, "Navi"), "KernelsLiteNavi." + ext, true, false)); return rv; } diff --git a/Tensile/Common.py b/Tensile/Common.py index dfabaefe0..f69098a3b 100644 --- a/Tensile/Common.py +++ b/Tensile/Common.py @@ -361,7 +361,7 @@ def getArchitectureName(gfxName): validMFMA["I8_940"] = [[32,32,4,2], [32,32,16,1], [16,16,4,4], [16,16,32,1], [4,4,4,16]] validMFMA["I8"] = validMFMA["H"] + validMFMA["F8"] validWMMA = [[16,16,16,1], ] -validTT = 16 +validTT = 64 validMFMA["_format9"] = [] for MFMA in [validMFMA["H"], validMFMA["S"], validMFMA["B"], validMFMA["D"], validMFMA["X"], validMFMA["F8"], validWMMA]: @@ -796,6 +796,9 @@ def getArchitectureName(gfxName): # - Can vectorize stores in edge tiles. Vector width can be up to AF0EM. # (since C matrix is always coalesced in Free0 index direction and this assertion guarantees the index element multiple) # + # TailLoop Optimizations: + # - enable wider global load with AF0EM > 1 for A + TLU, AF1EM > 1 for B + TLU + # # 1 indicates no assertion (since all sizes are multiples of 1) "AssertFree0ElementMultiple" : [1,2,4,8,16], @@ -1080,9 +1083,18 @@ def getArchitectureName(gfxName): # 6= +NoMAC # 7= +NoPreLoop+ NoGlobalReadInc # 9= NullKernel + # 10= +invalid LocalReadA (use invalid vgpr offset(LdsOOB)). Negative only. + # 11= +invalid LocalReadB (use invalid vgpr offset(LdsOOB)). Negative only. + # 12= +invalid LocalReadA+B (use invalid vgpr offset(LdsOOB)). Negative only. + # 13= +invalid LocalWriteA (use invalid vgpr offset(LdsOOB)). Negative only. + # 14= +invalid LocalWriteB (use invalid vgpr offset(LdsOOB)). Negative only. + # 15= +invalid LocalWriteA+B (use invalid vgpr offset(LdsOOB)). Negative only. + # 16= +invalid GlobalReadA (use srdA[2]=0, BufferLoad only). Negative only. + # 17= +invalid GlobalReadB (use srdB[2]=0, BufferLoad only). Negative only. + # 18= +invalid GlobalReadA+B (use srdB[2]=0, BufferLoad only). Negative only. # For example set DisableKernelPieces: [0,1,2,3,4,5,6,7,9] # this will create a set of kernels with progressively more pieces of the kernel disabled - "DisableKernelPieces": list(range(-9,10)), # disable pieces of the kernel, for performance isolation + "DisableKernelPieces": list(range(-18,10)), # disable pieces of the kernel, for performance isolation # assume atomics always work correctly. 
"DisableAtomicFail": [False, True], @@ -1092,6 +1104,15 @@ def getArchitectureName(gfxName): # fp16 alternate implementation round mode: false for truncate, true for round near zero "Fp16AltImplRound": [False, True], + # StreamK kernels divide work evenly among CUs by splitting along MT and K dimensions + # Total work units are calculated as (#MTs x #LoopIters) and divided among workgroups + # In most cases each workgroup will calculate a partial tile that are accumulated in a fixup step in the same kernel + # 0: Standard data-parallel kernel + # 1: Basic StreamK atomic (uses atomics to accumulate partial tiles) + # 2: Basic StreamK non-atomic (uses workspace to store partial tiles, accumulate in deterministic fix-up step) + # 3: Two-Tile StreamK (non-atomic, each WG completes an even number of sk iterations, followed by an even number of dp tiles) + "StreamK": [0, 1, 2, 3], + # 0 : standard launch # N>0 : launch persistent kernel with N workgroups per compute unit # - Recommended min is enough WG to use all resources on the CU @@ -1195,8 +1216,8 @@ def getArchitectureName(gfxName): # place upper and lower limits on the skinny-ness of macro tiles; shape=1 means square tile, like 64x64. shape=4 means 4x64 or 64x4 or 128x8... # these will just mark some kernels as invalid so that fewer kernels will be checked - "MacroTileShapeMin": list(range(1, 256+1)), - "MacroTileShapeMax": list(range(1, 256+1)), + "MacroTileShapeMin": list(range(1, 512+1)), + "MacroTileShapeMax": list(range(1, 512+1)), # when loading all the data from global into lds requires multiple load instructions, these parameters govern which # loads will pull which rectangle of data from global into lds @@ -1243,15 +1264,19 @@ def getArchitectureName(gfxName): # performance so this has been deprecated and probably doesn't work # -1 means use same padding as the VectorWidth if TLU=0 else 0. (Padding only helps when transpose is required) # With MatrixInstruction: -1 means max(GRVW,MIInput) if TLU=0 - "LdsPadA": [ -1, 0, 1, 2, 3, 4, 8, 16, 32], - "LdsPadB": [ -1, 0, 1, 2, 3, 4, 8, 16, 32], + # SourceKernel case, convert -1 to 0. Please manually set LdsPad for SourceKernel + # SourceKernel requires LdsPadA==LdsPadB + "LdsPadA": list(range(-1, 128)), + "LdsPadB": list(range(-1, 128)), # Padding boundary for LDS. defines block-size for pad insertion. for every 'LdsBlockSizePerPad' bytes, LDS padding (pad value from LdsPad parameter) # is added (readOffset aware of the pad and adjusts offset value based on this parameter value). # Only support LdsBlockSizePerPad >= unrollDepth * BPE # 0 means disable LdsBlockSizePerPad, # -1 means round up to nearest power of 2 begin with 128 - "LdsBlockSizePerPad": [-1, 0, 64, 128, 256, 512, 1024], + # SourceKernel case, convert -1 to 0. Please manually set LdsBlockSizePerPad for SourceKernel + "LdsBlockSizePerPadA": [-1, 0, 64, 128, 256, 512, 1024, 2048, 4096], + "LdsBlockSizePerPadB": [-1, 0, 64, 128, 256, 512, 1024, 2048, 4096], # Transpose LDS format. Local store in Coalesced dimension , same as optimized global fetch dimension . applicable only in TLU=0 case for miSIMD(s) # TODO: No code for -1 ? @@ -1269,6 +1294,9 @@ def getArchitectureName(gfxName): # No need to increase miLatencyLeft in that case. "ExtraMiLatencyLeft": list(range(0,9,2)), + # Add extra latency to calculate number of MFMA to insert between local read and wait + "ExtraLatencyForLR": list(range(0,17,2)), + # Allocate dedicated vgpr for local read with packing # False: use tmp vgpr. 
Less vgpr usage, but not best for local read scheduling # True: use dedicated vgpr for local read with packing. Best for local read scheduling, but need more vgpr @@ -1277,6 +1305,10 @@ def getArchitectureName(gfxName): # Not effective for PrefetchLocalRead <= 1 "VgprForLocalReadPacking": [False, True], + # ClusterLocalRead enables wider local read and packing with v_perm_b32 for 8bit or 16bit data + # Works with VgprForLocalReadPacking=True + "ClusterLocalRead": [False, True], + # tinkered with adding extra syncs or waits in the assembly kernels to see if it would improve the sequencing between workgroups, "fully synchronous scheduling" is WAY more promising; this can be deprecated "PerformanceSyncLocation": list(range(-1, 16*16+1)), "PerformanceWaitLocation": list(range(-1, 16*16+1)), @@ -1364,14 +1396,17 @@ def getArchitectureName(gfxName): {"LocalDotLayout": [ 1 ] }, {"AggressivePerfMode": [ 1 ] }, {"KernelLanguage": [ "Source" ] }, - {"LdsPadA": [ 0 ] }, - {"LdsPadB": [ 0 ] }, - {"LdsBlockSizePerPad": [ 0 ] }, + {"LdsPadA": [ -1 ] }, + {"LdsPadB": [ -1 ] }, + {"LdsBlockSizePerPadA": [ -1 ] }, + {"LdsBlockSizePerPadB": [ -1 ] }, {"TransposeLDS": [ 0 ] }, {"UnrollMajorLDSA": [ False ] }, {"UnrollMajorLDSB": [ False ] }, {"ExtraMiLatencyLeft": [ 0 ] }, + {"ExtraLatencyForLR": [ 0 ] }, {"VgprForLocalReadPacking": [ False ] }, + {"ClusterLocalRead": [ False ] }, {"MaxOccupancy": [ 40 ] }, {"VectorWidth": [ -1 ] }, {"VectorStore": [ -1 ] }, @@ -1448,6 +1483,7 @@ def getArchitectureName(gfxName): {"GlobalSplitUAtomicAdd": [ False ] }, {"MacroTileShapeMin": [ 1 ] }, {"MacroTileShapeMax": [ 64 ] }, + {"StreamK": [ 0 ] }, {"PersistentKernel": [ 0 ] }, {"PersistentKernelAlongBatch":[ False ] }, # May be default True is better ? {"PackBatchDims": [ 0 ] }, @@ -1722,8 +1758,8 @@ def getArchitectureName(gfxName): "Fp32toFp8SWClip" : True, # only in-device SR for now - "StochasticRounding" : False # By default, IEEE RNE rounding - + "StochasticRounding" : False, # By default, IEEE RNE rounding + # Rounding mode for f32 to f8 down conversion # TODO in Future: # There are two different rounding modes for f32 to f8 down conversion: [0]: IEEE RNE mode and [1/2]: stochastic mode. @@ -2186,6 +2222,10 @@ def assignGlobalParameters( config ): if os.name == "nt": globalParameters["CurrentISA"] = (9,0,6) printWarning("Failed to detect ISA so forcing (gfx906) on windows") + if globalParameters["CurrentISA"] == (9,4,2) or globalParameters["CurrentISA"] == (11,0,0) or \ + globalParameters["CurrentISA"] == (11,0,1) or globalParameters["CurrentISA"] == (11,0,2): + printWarning("HardwareMonitor currently disabled for gfx942 or gfx1100/gfx1101/gfx1102") + globalParameters["HardwareMonitor"] = False # For ubuntu platforms, call dpkg to grep the version of hip-clang. This check is platform specific, and in the future # additional support for yum, dnf zypper may need to be added. 
On these other platforms, the default version of diff --git a/Tensile/Components/Signature.py b/Tensile/Components/Signature.py index 196055981..531860a5a 100644 --- a/Tensile/Components/Signature.py +++ b/Tensile/Components/Signature.py @@ -207,6 +207,9 @@ def __call__(self, writer): kStr += self.addArgument( 'C', '8', offset, "global_buffer", dstValueType, "generic"); offset += 8 kStr += self.addArgument( 'A', '8', offset, "global_buffer", srcValueTypeA, "generic"); offset += 8 kStr += self.addArgument( 'B', '8', offset, "global_buffer", srcValueTypeB, "generic"); offset += 8 + if kernel["StreamK"] == 2 or kernel["StreamK"] == 3: + kStr += self.addArgument( 'WS', '8', offset, "global_buffer", dstValueType, "generic"); offset += 8 + kStr += self.addArgument( 'Flags', '8', offset, "global_buffer", dstValueType, "generic"); offset += 8 kStr += self.addArgument("OffsetD", '8', offset, "by_value", "u64"); offset += 8 kStr += self.addArgument("OffsetC", '8', offset, "by_value", "u64"); offset += 8 @@ -264,6 +267,10 @@ def __call__(self, writer): kStr += self.addArgument( "NumWorkGroups0", '4', offset, "by_value", "u32"); offset += 4 kStr += self.addArgument( "NumWorkGroups1", '4', offset, "by_value", "u32"); offset += 4 + if kernel["StreamK"]: + kStr += self.addArgument("MagicNumberProblemNumGroupTiles0", '4', offset, "by_value", "u32"); offset += 4 + kStr += self.addArgument("MagicShiftProblemNumGroupTiles0", '4', offset, "by_value", "u32"); offset += 4 + if kernel["PersistentKernel"]: kStr += self.addArgument("MagicNumberProblemNumGroupTiles0", '4', offset, "by_value", "u32"); offset += 4 kStr += self.addArgument("MagicShiftProblemNumGroupTiles0", '4', offset, "by_value", "u32"); offset += 4 @@ -273,6 +280,18 @@ def __call__(self, writer): kStr += self.addArgument("MagicNumProblemNumGroupTiles0By1", '4', offset, "by_value", "u32"); offset += 4 kStr += self.addArgument("MagicShiftProblemNumGroupTiles0By1", '4', offset,"by_value", "u32"); offset += 4 + if kernel["StreamK"]: + kStr += self.addArgument("ItersPerTile", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("MagicNumberItersPerTile", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("MagicShiftItersPerTile", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("TotalIters", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("SKItersPerWG", '4', offset,"by_value", "u32"); offset += 4 + if kernel["StreamK"] == 3: # Two-tile SK + kStr += self.addArgument("skGrid", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("skTiles", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument("skExtraIters", '4', offset,"by_value", "u32"); offset += 4 + # kStr += self.addArgument("dpTilesPerWG", '4', offset,"by_value", "u32"); offset += 4 + kStr += self.addArgument( "NumFullBlocks", '4', offset, "by_value", "u32"); offset += 4 kStr += self.addArgument( "WgmRemainder1", '4', offset, "by_value", "u32"); offset += 4 kStr += self.addArgument( "MagicNumberWgmRemainder1", '4', offset, "by_value", "u32"); offset += 4 diff --git a/Tensile/Contractions.py b/Tensile/Contractions.py index 6aa4615c3..7176fccd1 100644 --- a/Tensile/Contractions.py +++ b/Tensile/Contractions.py @@ -383,6 +383,10 @@ def CompoundPredicates(cls, state, problemType): if ('_GlobalAccumulation' not in state) or (state['_GlobalAccumulation'] != 'MultipleBuffer'): rv += [cls("DeterministicMode", value = False)] + if ('StreamK' in state) and (state['StreamK'] == 1): + # StreamK 
= 1 uses atomic for partial tiles + rv += [cls("DeterministicMode", value = False)] + # debugging: mark this set to allow the problem always runnable with PK if 'PersistentKernel' in state and state['PersistentKernel']: rv += [cls("PersistentKernelCheck")] @@ -421,7 +425,7 @@ def CompoundPredicates(cls, state, problemType): val = min(val, state["AssertSizeLessThan"][1] - 1) rv += [cls('BufferStoreOffsetLimitCheck', value=val)] - if '_GlobalAccumulation' in state and state['_GlobalAccumulation'] != None: + if '_GlobalAccumulation' in state and state['_GlobalAccumulation'] != None and not state["StreamK"]: value = state['MinKForGSU'] * state['GlobalSplitU'] rv += [cls('GlobalSplitUCheckMinK', value=value)] @@ -448,6 +452,7 @@ class SizeMapping: 'packSummationDims', 'packBatchDims', 'magicDivAlg', + 'streamK', 'persistentKernel', 'persistentKernelAlongBatch', 'sourceKernel', @@ -462,6 +467,8 @@ def FromOriginalState(cls, d): globalAccum = 1 if d['_GlobalAccumulation'] == 'MultipleBuffer': globalAccum = 2 + if d['_GlobalAccumulation'] == 'PartialsBuffer': + globalAccum = 3 return cls(workGroup = d['WorkGroup'], macroTile = cls.ReadOriginalMacroTile(d), threadTile = d['ThreadTile'], @@ -472,6 +479,7 @@ def FromOriginalState(cls, d): staggerStrideShift = d['_staggerStrideShift'] if '_staggerStrideShift' in d else 0, packSummationDims = d['PackSummationDims'] if 'PackSummationDims' in d else 0, packBatchDims = d['PackBatchDims'] if 'PackBatchDims' in d else 0, + streamK = d['StreamK'] if 'StreamK' in d else 0, persistentKernel = d['PersistentKernel'] if 'PersistentKernel' in d else 0, persistentKernelAlongBatch = d['PersistentKernelAlongBatch'] if 'PersistentKernelAlongBatch' in d else False, magicDivAlg = d.get('MagicDivAlg', 1), diff --git a/Tensile/Hardware.py b/Tensile/Hardware.py index 06199fff6..01a591af9 100644 --- a/Tensile/Hardware.py +++ b/Tensile/Hardware.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,50 +33,67 @@ def FromISA(cls, isa): return cls("AMDGPU", value=cls("Processor", value=gfxArch)) @classmethod - def FromHardware(cls, isa, cuCount=None): + def FromHardware(cls, isa, cuCount=None, isAPU=None): gfxArch = Common.gfxName(isa) - if cuCount == None: + if cuCount == None and isAPU == None: return cls("AMDGPU", value=cls("Processor", value=gfxArch)) - else: + elif cuCount == None: + return cls("AMDGPU", value=cls.And([cls("Processor", value=gfxArch), + cls("IsAPU", value=isAPU)])) + elif isAPU == None: return cls("AMDGPU", value=cls.And([cls("Processor", value=gfxArch), cls("CUCount", value=cuCount)])) + else: + return cls("AMDGPU", value=cls.And([cls("Processor", value=gfxArch), + cls("CUCount", value=cuCount), + cls("IsAPU", value=isAPU)])) def __lt__(self, other): # Use superclass logic for TruePreds if other.tag == 'TruePred' or self.tag == 'TruePred': return super().__lt__(other) - # Compute unit counts are embedded as 'And' with - # 'Processor' and 'ComputeUnitCount' as children + # Compute unit counts or APU/XPU versions are embedded as 'And' with + # 'Processor', 'CUCount', and 'IsAPU' as children if self.value.tag == 'And': myAndPred = self.value myProcPred = next(iter(x for x in myAndPred.value if x.tag == "Processor"), None) myCUPred = next(iter(x for x in myAndPred.value if x.tag == "CUCount"), None) myCUCount = myCUPred.value if myCUPred != None else 0 + myIsAPUPred = next(iter(x for x in myAndPred.value if x.tag == "IsAPU"), None) + myIsAPU = myIsAPUPred.value if myIsAPUPred != None else -1 else: myProcPred = self.value myCUCount = 0 + myIsAPU = -1 if other.value.tag == 'And': otherAndPred = other.value otherProcPred = next(iter(x for x in otherAndPred.value if x.tag == "Processor"), None) otherCUPred = next(iter(x for x in otherAndPred.value if x.tag == "CUCount"), None) otherCUCount = otherCUPred.value if otherCUPred != None else 0 + otherIsAPUPred = next(iter(x for x in otherAndPred.value if x.tag == "IsAPU"), None) + otherIsAPU = otherIsAPUPred.value if otherIsAPUPred != None else -1 else: otherProcPred = other.value otherCUCount = 0 + otherIsAPU = -1 - # If CU properties are empty, then compare processor predicates - if myCUCount == otherCUCount == 0: - # Make sure that we have valid processor preds - assert myProcPred != None and otherProcPred != None, "Missing processor predicate" - assert myProcPred.tag == otherProcPred.tag == "Processor", "Invalid processor predicate" + # If APU properties are the same, then check CU count or architecture + if myIsAPU == otherIsAPU: + # If CU properties are empty, then compare processor predicates + if myCUCount == otherCUCount == 0: + # Make sure that we have valid processor preds + assert myProcPred != None and otherProcPred != None, "Missing processor predicate" + assert myProcPred.tag == otherProcPred.tag == "Processor", "Invalid processor predicate" - # Downgrade to base class so that we don't recurse - myProcPredCopy = copy.deepcopy(myProcPred) - otherProcPredCopy = copy.deepcopy(otherProcPred) - myProcPredCopy.__class__ = otherProcPredCopy.__class__ = Properties.Predicate - return myProcPredCopy < otherProcPredCopy + # Downgrade to base class so that we don't recurse + myProcPredCopy = copy.deepcopy(myProcPred) + otherProcPredCopy = copy.deepcopy(otherProcPred) + myProcPredCopy.__class__ = otherProcPredCopy.__class__ = Properties.Predicate + return 
myProcPredCopy < otherProcPredCopy - # Higher priority given to higher CU count - return myCUCount > otherCUCount + # Higher priority given to higher CU count + return myCUCount > otherCUCount + # APU sorted before XPU, and XPU sorted before generic + return myIsAPU > otherIsAPU diff --git a/Tensile/KernelWriter.py b/Tensile/KernelWriter.py index 2700c94a2..7cfea8efb 100644 --- a/Tensile/KernelWriter.py +++ b/Tensile/KernelWriter.py @@ -141,6 +141,8 @@ def countNumMfmaForCurrentOrNextLoopLR(self, kernel, tensorParametersA, tensorPa latencyForLR -= max(latencyLeft,0) # remaining latency in mfma if not curr: latencyForLR -= self.miLatency # last LR will have 1 mfma latency + # add extra latency + latencyForLR += kernel["ExtraLatencyForLR"] while latencyForLR > 0: latencyForLR -= self.miLatency latencyForLRCount += 1 @@ -268,6 +270,9 @@ def makeSchedule(self, kernel, tensorParametersA, tensorParametersB, localWriteE self.numMfmaForNextLoopLR = min(self.numMfmaForNextLoopLR,numMfmaPerIter-1) self.barrierMfmaIndex = numMfmaPerIter*(kernel["LoopIters"]-self.numItersPLR+1) - self.numMfmaForNextLoopLR - 1 if self.numItersPLR else 0 numMfmaBetweenLWandBarrier = 2 if kernel["MatrixInstM"] == 32 else 3 + if self.miLatency <= 4 and kernel["LoopIters"] >= 4: + # low latency MFMA and enough number of loop iteration case, we double numMfmaBetweenLWandBarrier + numMfmaBetweenLWandBarrier *= 2 # set and adjust lwEndMfmaIndex self.setAndAdjustLwEndMfmaIndex(kernel, tensorParametersA, tensorParametersB, numMfmaBetweenLWandBarrier, lastLoop) @@ -1472,20 +1477,25 @@ def makeSubIterSchedule(self, kernel, localReadCode, iteration, pointerLWCode, p for j in range(instPerPack): iterCode.addCode(packItems.pop(0)) curPackIdx += 1 - if packItems: + # insert second packing code only if miLatencyLeft is large enough + if packItems and self.miLatencyLeft > 2: for j in range(instPerPack): iterCode.addCode(packItems.pop(0)) curPackIdx += 1 # since packed register need to wait 2 quad cycle to finish packing # we insert pack instruction if we can, or s_nop + count = 0 # count number of cycle for nop to insert while curPackIdx < numPack+2: if packItems: for j in range(instPerPack): iterCode.addCode(packItems.pop(0)) curPackIdx += 1 else: - iterCode.addInst("s_nop ","0","VALU packing writes to be consumed by matrix instruction") + count += 1 curPackIdx += 1 + if count: + # insert 1 nop instruction + iterCode.addInst("s_nop ",str(count - 1),"VALU packing writes to be consumed by matrix instruction") if i == numMfmaPerIter - 1: while packItems: iterCode.addCode(packItems.pop(0)) @@ -3064,8 +3074,8 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): # These cases loop back and run the prefetch loop again # we need an extra barrier to ensure that the ds_reads (either for SR or MFMA) from previous iteration # have finished before we generate the prefetch for the next summation index. 
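# Note on the s_nop consolidation in makeSubIterSchedule above: the old path
# emitted one "s_nop 0" per missing pack slot, while the new path counts the
# missing slots and emits a single "s_nop <count-1>". Assuming s_nop N inserts
# N+1 wait states (as documented for these ISAs), both forms cover the same
# number of cycles, but the new form costs one instruction, e.g. for count == 3:
#   old: s_nop 0 ; s_nop 0 ; s_nop 0   -> 3 instructions, 3 wait states
#   new: s_nop 2                       -> 1 instruction,  3 wait states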
- if kernel["PersistentKernel"] or self.actualSummationLoops>1: - kl.append( self.indent + self.syncStr + "// for PersistentKernel " + self.endLine ) + if kernel["PersistentKernel"] or kernel["StreamK"] > 0 or self.actualSummationLoops>1: + kl.append( self.indent + self.syncStr + "// for PersistentKernel / StreamK " + self.endLine ) if self.enable["LocalWrite"]: # local write @@ -3321,6 +3331,8 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): if self.enable["Sync"]: kl.append(self.syncThreads(kernel)) + kl.append(self.doneGlobalABReads(kernel)) + # the following read/write addresses could be modified in recalcLocal(Read|Write)Addresses due to policy change self.oriLraA = None # back up original local read address vgpr self.oriLraB = None @@ -3392,7 +3404,11 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): mEnd = 1 if (kernel["DirectToVgprA"] or kernel["DirectToVgprB"] or kernel["DirectToLdsA"] or kernel["DirectToLdsB"]) \ and kernel["EnableMatrixInstruction"]: - mEnd = kernel["DepthU"]//(kernel["MatrixInstK"]*kernel["LocalSplitU"]) + mEnd = kernel["_DepthULds"]//(kernel["MatrixInstK"]*kernel["LocalSplitU"]) + elif kernel["EnableMatrixInstruction"] and \ + ((kernel["LdsPadA"] and kernel["LdsBlockSizePerPadA"]) or (kernel["LdsPadB"] and kernel["LdsBlockSizePerPadB"])): + # LdsPad + LBSPP case, address increment is not distributed uniformly. So, we need to unroll tail loop + mEnd = kernel["_DepthULds"]//(kernel["MatrixInstK"]*kernel["LocalSplitU"]) for mValue in range(mEnd): if mEnd > 1: @@ -3663,6 +3679,7 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB ): # - GlobalSplitU = 1 # GSU>1 case, remaining K is distributed unevenly and does not work with tailLoop in noLoadLoop # - PersistentKernel = 0 + # - StreamK = 0 # - DepthULdsDivisor = 1 # - StaggerU = 0 # StaggerU=0 case, we can exit NoLoadLoop earlier when whole K range is processed @@ -3709,7 +3726,7 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB ): elif kernel["BufferLoad"] and (not kernel["SuppressNoLoadLoop"]) and \ kernel["EnableMatrixInstruction"] and kernel["MatrixInstK"] > 1 and \ (glvwA <= 1 or (tailLoopLoadWidthA % glvwA == 0)) and (glvwB <= 1 or (tailLoopLoadWidthB % glvwB == 0)) and \ - gsu == 1 and kernel["PersistentKernel"] == 0 and kernel["DepthULdsDivisor"] == 1 and \ + gsu == 1 and kernel["PersistentKernel"] == 0 and kernel["StreamK"] == 0 and kernel["DepthULdsDivisor"] == 1 and \ kernel["InnerUnroll"] == 1: if kernel["StaggerU"] == 0: noTailLoop = 2 @@ -3752,6 +3769,12 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB ): self.enable["Sync"] = True and not (dkp>0 and dkp >= 5) and not dkp == -5 self.enable["MAC"] = True and not (dkp>0 and dkp >= 6) and not dkp == -6 self.enable["PostLoop"] = True and not (dkp>0 and dkp >= 1) and not dkp == -1 + self.enable["InvalidLocalReadA"] = dkp == -10 or dkp == -12 + self.enable["InvalidLocalReadB"] = dkp == -11 or dkp == -12 + self.enable["InvalidLocalWriteA"] = dkp == -13 or dkp == -15 + self.enable["InvalidLocalWriteB"] = dkp == -14 or dkp == -15 + self.enable["InvalidGlobalReadA"] = (dkp == -16 or dkp == -18) and kernel["BufferLoad"] + self.enable["InvalidGlobalReadB"] = (dkp == -17 or dkp == -18) and kernel["BufferLoad"] #if dkp: # print "\nKernelWriter enable:", self.enable @@ -3856,12 +3879,13 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB ): # EnableMatrixInstruction + MIInputPerThread > 1 # SourceSwap only (TODO: non SourceSwap) # 
VgprForLocalReadPacking (need dedicated vpgr for packing) + # ClusterLocalRead # not UnrollMajorLDS # VectorWidthA,B > 1 self.lrvwTileA = 1 self.lrvwTileB = 1 if kernel["EnableMatrixInstruction"] and kernel["MIInputPerThread"] > 1 and\ - kernel["SourceSwap"] and kernel["VgprForLocalReadPacking"]: + kernel["SourceSwap"] and kernel["VgprForLocalReadPacking"] and kernel["ClusterLocalRead"]: if (not kernel["UnrollMajorLDSA"]): self.lrvwTileA = min(kernel["MIInputPerThread"], kernel["VectorWidth"]) # should not exceed MIInputPerThread if (not kernel["UnrollMajorLDSB"]): @@ -4705,6 +4729,13 @@ def globalReadIncrementAB(self, kernel, loopIdx, prefetchIndex, incs=1): @abc.abstractmethod def globalReadDo(self, kernel, mode, tP, vregSetIdx=0): return "" + + ############################################################################## + # Global Read A/B completed + ############################################################################## + @abc.abstractmethod + def doneGlobalABReads(self, kernel): + return "" ############################################################################## # directToLds m0 update: Do It A/B diff --git a/Tensile/KernelWriterAssembly.py b/Tensile/KernelWriterAssembly.py index f12956e28..18fed4118 100644 --- a/Tensile/KernelWriterAssembly.py +++ b/Tensile/KernelWriterAssembly.py @@ -30,7 +30,7 @@ from .Utils import ceil_divide from .AsmMemoryInstruction import MemoryInstruction, getGlcBitName, getSlcBitName from .AsmRegisterPool import RegisterPool -from .AsmUtils import inst, vgpr, sgpr, accvgpr, log2, vectorStaticDivideAndRemainder, vectorStaticDivide, vectorStaticRemainder, scalarStaticDivideAndRemainder, vectorStaticMultiply, staticMultiply, scalarStaticMultiply +from .AsmUtils import inst, vgpr, sgpr, accvgpr, log2, vectorStaticDivideAndRemainder, vectorStaticDivide, vectorStaticRemainder, scalarStaticDivideAndRemainder, vectorStaticMultiply, staticMultiply, scalarStaticMultiply, instCommentOnly from math import ceil, trunc, modf, log from copy import deepcopy @@ -377,6 +377,8 @@ def strideRef(self, tc, dim): dim is index 0...max indices and is in global index space. """ problemType = self.kernel["ProblemType"] + if tc == 'WS': # WS used by Stream-K kernel + tc = 'D' if tc in ['A','B']: if not problemType["UseInitialStridesAB"] and \ dim == problemType["IndexAssignments%s"%tc][0]: @@ -547,6 +549,7 @@ def undefineSgpr(self, name): # later references will result in compile-time error (with odd 'error: expected relocatable expression') # and 'Kernel ... not found in any loaded module' # TODO: temporarily disable undef as it seems to have issues + # return ".set sgpr%s, UNDEF\n" % name return ".set %s, UNDEF\n" % name def defineVariableSgprs(self, kernel): @@ -562,6 +565,16 @@ def defineVariableSgprs(self, kernel): # product of all summation dimensions, this also will be divided if GSU is enabled self.defineSgpr("UnrollLoopLastIter", 1) + if kernel["StreamK"]: + # StreamK vars + self.defineSgpr("StreamKIdx", 1) + self.defineSgpr("StreamKIter", 1) + self.defineSgpr("StreamKIterEnd", 1) + self.defineSgpr("StreamKLocalStart", 1) + self.defineSgpr("StreamKLocalEnd", 1) + if kernel["StreamK"] == 2 or kernel["StreamK"] == 3: + self.defineSgpr("SrdWS", 4, 4) + if kernel["PackSummationDims"] and kernel["GlobalSplitU"]>1: self.defineSgpr("GsuNumIter%s"%self.loopChar(kernel,self.unrollIdx), 1) @@ -684,7 +697,7 @@ def defineVariableSgprs(self, kernel): % (self.sgprPool.size(), self.maxSgprs)) # TODO-persistent - likely recompute some of the registers above. 
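# Reading aid for the Stream-K SGPRs declared in defineVariableSgprs above,
# inferred from their use in openPersistentLoop/graWorkGroup further down
# (a summary, not an authoritative spec):
#   StreamKIdx        - this workgroup's original flat workgroup index
#   StreamKIter       - first iteration assigned to this WG in the flattened
#                       (tile x unroll-iteration) space
#   StreamKIterEnd    - one past the last iteration assigned to this WG
#   StreamKLocalStart - first iteration of the current tile handled by this WG
#   StreamKLocalEnd   - one past the last iteration of the current tile here
#   SrdWS             - buffer SRD for the partial-tile workspace, only defined
#                       for StreamK == 2 or 3 (the non-atomic variants that
#                       accumulate partial tiles in the fix-up step)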
- if kernel["PersistentKernel"]: + if kernel["PersistentKernel"] or kernel["StreamK"]: self.lastPostLoopSgpr = self.sgprPool.size() ############################################################################## @@ -780,7 +793,7 @@ def initKernel(self, kernel, tPA, tPB ): # - not BufferLoad: requires WorkGroup0,1 (means need to calculate WGM) #- groOffsetInMacroTile=0: requires WorkGroup0,1 (means need to calculate WGM) if ((kernel.enabledSplitLDS and (kernel["UnrollMajorLDSA"] or kernel["UnrollMajorLDSB"])) or \ - kernel["PersistentKernel"] or \ + kernel["PersistentKernel"] or kernel["StreamK"] or \ (not kernel["BufferLoad"]) or \ self.groOffsetInMacroTile == 0): # disable init opt for local write @@ -1029,7 +1042,7 @@ def initKernel(self, kernel, tPA, tPB ): # : int8x4-gemm (internal = i32) self.bpeCinternal = int(self.bpr * kernel["ProblemType"]["ComputeDataType"].numRegisters()) - if kernel["_GlobalAccumulation"]: + if kernel["_GlobalAccumulation"] and kernel["_GlobalAccumulation"] != 'PartialsBuffer': self.bpeCexternal = self.bpeCinternal # special case for wmma h and b @@ -1037,7 +1050,7 @@ def initKernel(self, kernel, tPA, tPB ): and globalParameters["AsmCaps"][self.version]["HasWMMA"] and (kernel["ProblemType"]["ComputeDataType"].numRegisters() == 0.5)): self.bpeCinternal = 4 - if kernel["_GlobalAccumulation"]: + if kernel["_GlobalAccumulation"]: # TODO SK and kernel["_GlobalAccumulation"] != 'PartialsBuffer': self.bpeCexternal = 2 #jgolds Need to check device for support @@ -1684,6 +1697,8 @@ def initKernel(self, kernel, tPA, tPB ): numSgprAddressC = self.rpga # til end numSgprAddressA = self.rpga # til read offsets numSgprAddressB = self.rpga # til read offsets + numSgprAddressWS = self.rpga + numSgprAddressFlags = self.rpga # would not less than 1 reg, # since even if ComputeType = H, we still pass the arg as a 32-bit (concate two 16-bit) numSgprAlpha = max(1,int(self.bpeCinternal/4)) @@ -1825,8 +1840,11 @@ def initKernel(self, kernel, tPA, tPB ): self.defineSgpr("AddressC", numSgprAddressC) self.defineSgpr("AddressA", numSgprAddressA) self.defineSgpr("AddressB", numSgprAddressB) - self.argOffsetOffset = self.argAddressOffset + (numSgprAddressD + numSgprAddressC + numSgprAddressA + numSgprAddressB) * 4 + if kernel["StreamK"] == 2 or kernel["StreamK"] == 3: + self.defineSgpr("AddressWS", numSgprAddressWS) + self.defineSgpr("AddressFlags", numSgprAddressFlags) + self.argOffsetOffset += (numSgprAddressWS + numSgprAddressFlags) * 4 self.defineSgpr("OffsetD", self.numSgprOffsetD) self.defineSgpr("OffsetC", self.numSgprOffsetC) @@ -1886,6 +1904,25 @@ def initKernel(self, kernel, tPA, tPB ): self.defineSgpr("MagicNumProblemNumGroupTiles0By1", 1) # for PKAB, use for Magic Div Alg 2 by (nwg0*nwg1) self.defineSgpr("MagicShiftProblemNumGroupTiles0By1", 1) # for PKAB, use for Magic Div Alg 2 by (nwg0*nwg1) pkArgumentToLoad += 3 + skArgumentToLoad = 0 + if kernel["StreamK"]: + # StreamK args + self.defineSgpr("MagicNumberProblemNumGroupTiles0", 1) # Magic number to use for division + self.defineSgpr("MagicShiftProblemNumGroupTiles0", 1) # Magic shift/abit to use for division alg 2 + self.defineSgpr("ItersPerTile", 1) + self.defineSgpr("MagicNumberItersPerTile", 1) + self.defineSgpr("MagicShiftItersPerTile", 1) + self.defineSgpr("TotalIters", 1) + self.defineSgpr("SKItersPerWG", 1) + skArgumentToLoad += 7 + if kernel["StreamK"] == 3: # Two-tile SK + self.defineSgpr("skGrid", 1) + self.defineSgpr("skTiles", 1) + self.defineSgpr("skExtraIters", 1) + # self.defineSgpr("dpTilesPerWG", 1) + 
skArgumentToLoad += 3 + + #------------------------ # Registers defined below this point are not available in the post-loop # Post-loop is after tail loop exits, ie the store code. @@ -1915,7 +1952,9 @@ def initKernel(self, kernel, tPA, tPB ): if kernel["LocalWriteUseSgprB"]: self.defineSgpr("LocalWriteAddrB", 1) - self.numSgprToLoad = 2 + 2 + numSgprAddressD + numSgprAddressC + numSgprAddressA + numSgprAddressB + numSgprAlpha + \ + self.numSgprToLoad = 2 + 2 + numSgprAddressD + numSgprAddressC + numSgprAddressA + numSgprAddressB + \ + ((numSgprAddressWS + numSgprAddressFlags) if kernel["StreamK"] >= 2 else 0) + \ + numSgprAlpha + \ (numSgprBeta if kernel["ProblemType"]["UseBeta"] else 0) + self.numSgprStridesD + self.numSgprStridesC + self.numSgprStridesA + \ self.numSgprStridesB + self.numSgprSizesFree + self.numSgprSizesSum + \ len(self.sumMagicParms)*2 + len(kernel["PackedC0IdxChars"][:-1])*2 + \ @@ -1923,6 +1962,7 @@ def initKernel(self, kernel, tPA, tPB ): 1 + \ 2 + \ pkArgumentToLoad + \ + skArgumentToLoad + \ 3 + \ self.numSgprOffsetD + self.numSgprOffsetC + self.numSgprOffsetA + self.numSgprOffsetB @@ -2737,7 +2777,12 @@ def functionSignature(self, kernel ): if kernel["BufferLoad"] or kernel["BufferStore"]: kStr += self.comment1("2GB limit - set offsets to -1 to exceed this and clamp") - kStr += self.macroRegister("BufferLimit", "0xffffffff") + # for A + limitValue = 0 if self.enable["InvalidGlobalReadA"] else 0xffffffff + kStr += self.macroRegister("BufferLimitA", hex(limitValue)) + # for B + limitValue = 0 if self.enable["InvalidGlobalReadB"] else 0xffffffff + kStr += self.macroRegister("BufferLimitB", hex(limitValue)) #TODO-64 : This is max 32-bit negative value, the tail loop # does incrementally step through the GRO and increment GRO # which are initialized with this value @@ -3283,6 +3328,10 @@ def allocateResources(self, kernel, lraCode=None): # however, in order to match sgpr to kernel argument memory, some unnecessarily sgpr will also be defined, and caused wasting of sgpr. 
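# The Stream-K kernel arguments declared in initKernel above (ItersPerTile,
# TotalIters, SKItersPerWG, skGrid, skTiles, skExtraIters) are produced on the
# host. A minimal sketch of that computation for the two-tile case (StreamK=3)
# follows; it only mirrors what the kernel-side code in this patch implies.
# The real host logic lives in ContractionSolution.cpp, and the function name
# and signature here are illustrative, not part of Tensile.
import math

def streamk_launch_args(num_tiles, size_k, depth_u, sk_grid, sk_tiles):
    iters_per_tile  = math.ceil(size_k / depth_u)  # unroll iterations per macro-tile
    total_iters     = num_tiles * iters_per_tile   # work units = #MTs x #LoopIters
    sk_iters        = sk_tiles * iters_per_tile    # iterations covered stream-k style
    sk_iters_per_wg = sk_iters // sk_grid          # even share per workgroup
    sk_extra_iters  = sk_iters % sk_grid           # WGs with idx < skExtraIters do one more
    return dict(ItersPerTile=iters_per_tile, TotalIters=total_iters,
                SKItersPerWG=sk_iters_per_wg, skExtraIters=sk_extra_iters,
                skGrid=sk_grid, skTiles=sk_tiles)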
# TODO: more efficient way is to organize both sgpr and kernel argument memory in API + # KernArgAddress needed for general batch after loading arguments + if kernel["ProblemType"]["StridedBatched"] or not kernel["ProblemType"]["Batched"]: + self.undefineSgpr("KernArgAddress") + if kernel.enabledSetPrioSplitLDS: kStr += inst("s_setprio", "1", "prioritize init code so as to issue load sooner") @@ -3314,7 +3363,7 @@ def allocateResources(self, kernel, lraCode=None): self.sgprAddressStrAB = "Address" self.releaseSgprAdressCD = False self.sgprAddressStrCD = "Address" - if (not kernel["PersistentKernel"]) and kernel["CheckDimOverflow"]<2 and kernel["ProblemType"]["OperationType"] == 'GEMM': + if (not kernel["PersistentKernel"] and not kernel["StreamK"]) and kernel["CheckDimOverflow"]<2 and kernel["ProblemType"]["OperationType"] == 'GEMM': # A,B check if kernel["BufferLoad"]: self.releaseSgprAdressAB = True @@ -3325,7 +3374,7 @@ def allocateResources(self, kernel, lraCode=None): self.sgprAddressStrCD = "Srd" # add offset to buffer - if not kernel["_GlobalAccumulation"]: + if not kernel["_GlobalAccumulation"] or kernel["_GlobalAccumulation"] == 'PartialsBuffer': kStr += inst("s_lshl_b64", sgpr("OffsetD", 2), sgpr("OffsetD", 2), hex(log2(self.bpeCexternal)), "elements offset to bytes offset") kStr += inst("s_add_u32", sgpr("%sD+0"%self.sgprAddressStrCD), sgpr("AddressD+0"), sgpr("OffsetD"), "add offset to buffer address") kStr += inst("s_addc_u32", sgpr("%sD+1"%self.sgprAddressStrCD), sgpr("AddressD+1"), sgpr("OffsetD+1"), "add offset to buffer address") @@ -3606,7 +3655,44 @@ def extractPackedCoord1ToRowStart(self, kernel, packedC1, packedCoordVgpr, store ############################################################################## def openPersistentLoop(self, kernel): kStr = "" - if kernel["PersistentKernel"]: + if kernel["PersistentKernel"] or kernel["StreamK"]: + if kernel["StreamK"]: + # Workload calculations + kStr += inst("s_mov_b32", sgpr("StreamKIdx"), sgpr("WorkGroup0"), "Save original StreamK index") + if kernel["StreamK"] < 3: # Basic SK + kStr += inst("s_mul_i32", sgpr("StreamKIter"), sgpr("StreamKIdx"), sgpr("SKItersPerWG"), "StreamK starting iteration") + kStr += inst("s_add_u32", sgpr("StreamKIterEnd"), sgpr("StreamKIter"), sgpr("SKItersPerWG"), "StreamK ending iteration") + kStr += inst("s_min_u32", sgpr("StreamKIterEnd"), sgpr("StreamKIterEnd"), sgpr("TotalIters"), "Cap ending iter at total iters") + kStr += inst("s_cmp_lt_u32", sgpr("StreamKIter"), sgpr("StreamKIterEnd"), "Make sure there's work to do") + kStr += self.longBranchScc0("label_%04u" % (self.getLabelNum("KernelEnd")), positiveOnly=True) + # kStr += inst("s_cbranch_scc0", "label_%04u" % (self.getLabelNum("KernelEnd")), "edge case that work doesn't divide well") + kStr += self.undefineSgpr("TotalIters") + elif kernel["StreamK"] == 3: # Two-tile SK + # iter count after all extra iters have been distributed + kStr += inst("s_mul_i32", sgpr("StreamKIter"), sgpr("StreamKIdx"), sgpr("SKItersPerWG"), "StreamK starting iteration (case: after extra iters)") + kStr += inst("s_add_u32", sgpr("StreamKIter"), sgpr("StreamKIter"), sgpr("skExtraIters"), "Add extra iters") + kStr += inst("s_add_u32", sgpr("StreamKIterEnd"), sgpr("StreamKIter"), sgpr("SKItersPerWG"), "StreamK ending iteration (case: after extra iters)") + # iter count before all extra iters have been distributed + # stmp+1 = SKItersPerWG + 1 extra iteration + sIter = self.sgprPool.checkOut(2, "SKIter", preventOverflow=0) + kStr += inst("s_add_u32", sgpr(sIter+1), 
sgpr("SKItersPerWG"), 1, "Spread out extra iterations") + kStr += inst("s_mul_i32", sgpr(sIter), sgpr("StreamKIdx"), sgpr(sIter+1), "StreamK starting iteration (case: before extra iters)") + kStr += inst("s_add_u32", sgpr(sIter+1), sgpr(sIter), sgpr(sIter+1), "StreamK ending iteration (case: before extra iters)") + # select correct start/end iteration index + kStr += inst("s_cmp_lt_u32", sgpr("StreamKIdx"), sgpr("skExtraIters"), "Check if lane gets an extra iteration") + kStr += inst("s_cselect_b32", sgpr("StreamKIter"), sgpr(sIter), sgpr("StreamKIter"), "Set start iter") + kStr += inst("s_cselect_b32", sgpr("StreamKIterEnd"), sgpr(sIter+1), sgpr("StreamKIterEnd"), "Set end iter") + self.sgprPool.checkIn(sIter) + # clamp to end of sk iterations + # TODO maybe remove clamp, since extra iters code should guarantee total iterations match + stmp = self.sgprPool.checkOut(1, "TotalSKIters", preventOverflow=0) + kStr += inst("s_mul_i32", sgpr(stmp), sgpr("skTiles"), sgpr("ItersPerTile"), "Total SK iters") + kStr += inst("s_min_u32", sgpr("StreamKIterEnd"), sgpr("StreamKIterEnd"), sgpr(stmp), "Cap ending iter at total SK iters") + self.sgprPool.checkIn(stmp) + # check if this WG has no work to do + kStr += inst("s_cmp_lt_u32", sgpr("StreamKIter"), sgpr("TotalIters"), "Make sure there's work to do") + kStr += self.longBranchScc0("label_%04u" % (self.getLabelNum("KernelEnd")), positiveOnly=True) + kStr += self.comment3("Persistent Loop Start") kStr += self.getLabelDef("PersistentLoopStart") # kStr += inst("s_add_u32", sgpr("PersistentLoopIter"), sgpr("PersistentLoopIter"), hex(1), "Inc PersistentLoop Iter") # Back-up: not needed now @@ -3620,9 +3706,69 @@ def openPersistentLoop(self, kernel): def graWorkGroup(self, kernel, isPap): kStr = "" + if kernel["StreamK"]: + # StreamK workgroup mapping + stmp = self.sgprPool.checkOut(4, "SKMappingTemp", preventOverflow=0) + # Always reset pointers to handle odd-exit case which moves LRO to the upper bank + if not self.prefetchAcrossPersistent and kernel["PrefetchGlobalRead"]: + kStr += self.localReadResetOffsets(kernel, self.tPA) + kStr += self.localReadResetOffsets(kernel, self.tPB) + + kStr += self.comment1("StreamK calculate tile idx and map to WG") + + # stmp = tile index + kStr += self.sMagicDivAlg2(kernel, stmp, sgpr("StreamKIter"), sgpr("MagicNumberItersPerTile"), sgpr("MagicShiftItersPerTile")) + # stmp+1 = tile start + kStr += inst("s_mul_i32", sgpr(stmp+1), sgpr(stmp), sgpr("ItersPerTile"), "Tile start iteration") + # stmp+2 = tile end + kStr += inst("s_add_u32", sgpr(stmp+2), sgpr(stmp+1), sgpr("ItersPerTile"), "Tile end iteration") + # local start + kStr += inst("s_sub_u32", sgpr("StreamKLocalStart"), sgpr("StreamKIter"), sgpr(stmp+1), "Local iteration start") + # local end (SK tile) + kStr += inst("s_min_u32", sgpr("StreamKLocalEnd"), sgpr("StreamKIterEnd"), sgpr(stmp+2), "1. (Local) iteration end (SK tile)") + kStr += inst("s_sub_u32", sgpr("StreamKLocalEnd"), sgpr("StreamKLocalEnd"), sgpr(stmp+1), "2. 
Local iteration end (SK tile)") + + if kernel["StreamK"] == 3: # Two-tile algorithm + # local end (DP tile) + kStr += inst("s_sub_u32", sgpr(stmp+3), sgpr(stmp+2), sgpr(stmp+1), "Local iteration end (DP tile)") + # select correct local end + kStr += inst("s_cmp_lt_u32", sgpr("StreamKIter"), sgpr("StreamKIterEnd"), "Check if in SK or DP section") + kStr += inst("s_cselect_b32", sgpr("StreamKLocalEnd"), sgpr("StreamKLocalEnd"), sgpr(stmp+3), "Apply SK or DP end iteration") + + # Increment StreamK iteration + # If moving from SK to DP, next iteration is first DP + # stmp = offset to first DP tile + kStr += inst("s_mul_i32", sgpr(stmp+3), sgpr("skTiles"), sgpr("ItersPerTile"), "Offset to first DP tile") + kStr += inst("s_mul_i32", sgpr(stmp+1), sgpr("StreamKIdx"), sgpr("ItersPerTile"), "WG tile offset") + kStr += inst("s_add_u32", sgpr(stmp+3), sgpr(stmp+3), sgpr(stmp+1), "DP start offset + WG offset") + # If already in DP, add dpShift + kStr += inst("s_mul_i32", sgpr(stmp+1), sgpr("skGrid"), sgpr("ItersPerTile"), "DP iterations shift") + kStr += inst("s_add_u32", sgpr(stmp+1), sgpr(stmp+1), sgpr("StreamKIter"), "Add DP shift") + # Save DP iter in stmp + kStr += inst("s_cmp_lt_u32", sgpr("StreamKIter"), sgpr("StreamKIterEnd"), "Check if in SK or DP section") + kStr += inst("s_cselect_b32", sgpr(stmp+3), sgpr(stmp+3), sgpr(stmp+1), "Select first DP tile, or add DP shift") + # If staying in SK portion, next iteration is stmp+2 + kStr += inst("s_cmp_lt_u32", sgpr(stmp+2), sgpr("StreamKIterEnd"), "Check if there are more SK tiles") + kStr += inst("s_cselect_b32", sgpr("StreamKIter"), sgpr(stmp+2), sgpr(stmp+3), "Select next SK or DP tile") + + else: # Basic Stream-K + # Increment StreamK iteration + kStr += inst("s_mov_b32", sgpr("StreamKIter"), sgpr(stmp+2), "Increment StreamK Iteration") + + # Map StreamK tile index to wg0/1 + kStr += self.comment1("Map StreamK tile index to wg0/1") + kStr += self.sMagicDivAlg2(kernel, stmp+1, sgpr(stmp), sgpr("MagicNumberProblemNumGroupTiles0"), sgpr("MagicShiftProblemNumGroupTiles0")) + kStr += inst("s_mov_b32", sgpr("WorkGroup1"), sgpr(stmp+1), "wg1 = Tile Idx / problemNumGroupTiles0") + kStr += inst("s_mul_i32", sgpr("WorkGroup0"), sgpr(stmp+1), sgpr("NumWorkGroups0"), "remainder part 1 : quotient * divisor") + kStr += inst("s_sub_u32", sgpr("WorkGroup0"), sgpr(stmp), sgpr("WorkGroup0"), "wg0 = Tile Idx % problemNumGroupTiles0") + + kStr += "\n" + + self.sgprPool.checkIn(stmp) + # del stmpRef + if kernel["PersistentKernel"]: - stmpRef = self.getTmpSgpr(8, 4) - stmp = stmpRef.idx() + stmp = self.sgprPool.checkOutAligned(8, 4, "PKMappingTemp", preventOverflow=0) # Always reset pointers to handle odd-exit case which moves LRO to the upper bank if not self.prefetchAcrossPersistent and kernel["PrefetchGlobalRead"]: kStr += self.localReadResetOffsets(kernel, self.tPA) @@ -3691,6 +3837,7 @@ def graWorkGroup(self, kernel, isPap): #kStr += self.assert_ne(sgpr("SerialWorkGroupIter"), 2) kStr += "\n" + self.sgprPool.checkIn(stmp) kStr += self.comment1("graWorkGroup mapping") if kernel["GlobalSplitU"] > 1: @@ -3702,7 +3849,7 @@ def graWorkGroup(self, kernel, isPap): nwg1 = self.vgprPool.checkOut(1, "nwg1", self.preventVgprOverflowDuringNewTile) quotient = self.vgprPool.checkOut(1, "quotient", self.preventVgprOverflowDuringNewTile) tmpVgpr = self.vgprPool.checkOut(1, "tmpVgpr", self.preventVgprOverflowDuringNewTile) - tmpSgpr = self.getTmpSgpr(1).idx() + tmpSgpr = self.sgprPool.checkOut(1, "GSUMappingTemp", preventOverflow=0) kStr += "// GSU-WGMapRR :nwg1 = (size%s + 
MT%s - 1) / MT%s;%s" \ % (self.tileChar1, self.tileChar1, self.tileChar1, self.endLine) kStr += inst("v_mov_b32", vgpr(nwg1), sgpr("SizesFree+1"), "") @@ -3739,13 +3886,14 @@ def graWorkGroup(self, kernel, isPap): self.vgprPool.checkIn(wg1) self.vgprPool.checkIn(quotient) self.vgprPool.checkIn(remainder) + self.sgprPool.checkIn(tmpSgpr) else: kStr += "// GSU-not-WGMapRR :nwg1 = (size%s + MT%s - 1) / MT%s;%s" \ % (self.tileChar1, self.tileChar1, self.tileChar1, self.endLine) # gsuSumIdx = wg1 % GSU # wg1 = wg1 / GSU - tmpSgpr = self.getTmpSgpr(3).idx() # needs 3 + tmpSgpr = self.sgprPool.checkOutAligned(3, 2, "GSUMappingTemp", preventOverflow=0) divisor = tmpSgpr+2 kStr += inst("s_mov_b32", sgpr(divisor), sgpr("WorkGroup1"), \ "copying for divisor") @@ -3764,6 +3912,7 @@ def graWorkGroup(self, kernel, isPap): #kStr += dump(vgpr(tmp)) # remainder #self.vgprPool.checkIn(tmp) #kStr += "s_endpgm\n" + self.sgprPool.checkIn(tmpSgpr) ######################################## # Blocked rows or columns @@ -3772,7 +3921,8 @@ def graWorkGroup(self, kernel, isPap): smallNumMagicShift = 31 magicNumberWgm = ((1< 1: + tmpVgpr = self.vgprPool.checkOut(1,"tmpVgpr") + # generate the code only when num1DBlocks > 1. + # if num1DBlocks is 1, % num1DBlocks is always 0 and no difference in rReg value + kStr += vectorStaticDivide(tmpVgpr, qReg, dividedForBlkId, tmpSgpr, \ + "2. block offset: bnIdx = wtid / dividedForBlkId(%u)" % dividedForBlkId) + kStr += vectorStaticRemainder(tmpVgpr, tmpVgpr, num1DBlocks, tmpSgpr, \ + "2. block offset: bnIdx = bnIdx %% num1DBlocks(%u)" % num1DBlocks) # assuming num1DBlocks is power of 2 to use same vreg for src and dst + kStr += staticMultiply(vgpr(tmpVgpr), vgpr(tmpVgpr), strideBlock, sgpr(tmpSgpr), \ + "2. block offset: bnOffset = bnIdx * strideBlock(%u)" % strideBlock) + kStr += inst("_v_add_u32", vgpr(rReg), vgpr(tmpVgpr), vgpr(rReg), \ + "3. add N and block offset: bnOffset = block and N offset") + self.vgprPool.checkIn(tmpVgpr) + else: + # comment only because bnIdx = bnIdx % num1DBlocks(1) = 0 + kStr += instCommentOnly("2. block offset: bnIdx = bnIdx %% num1DBlocks(%u) is 0. 
do nothing" % num1DBlocks) # unroll offset # need division for qReg kStr += vectorStaticDivide(qReg, qReg, dividendForKId, tmpSgpr, \ @@ -4685,6 +4858,16 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe, isPap): kStr += self.s_mul_u64_u32(sgpr(tileStart), sgpr(tileStart+1), sgpr(tileStart+0), \ strideF, "tlu=0, scaled tile-offset by stride") + if kernel["StreamK"]: + # StreamK partial tile - offset to tile start index + kStr += inst("s_mul_i32", sgpr(stmp), sgpr("StreamKLocalStart"), "DepthU", "StreamK tile start offset") + strideL = self.strideRef(tc, kernel["ProblemType"]["IndicesSummation"][0]) + kStr += self.s_mul_u64_u32(sgpr(stmp), sgpr(stmp+1), sgpr(stmp), strideL, "StreamK tile start offset") + if kernel["CheckDimOverflow"] >=2: + kStr += self.assert_eq(sgpr(stmp+1),0) + kStr += inst("s_add_u32", sgpr(tileStart+0), sgpr(tileStart+0), sgpr(stmp+0), "accum GsuOffset term to tilestart") + kStr += inst("s_addc_u32", sgpr(tileStart+1), sgpr(tileStart+1), sgpr(stmp+1), "accum GsuOffset term to tilestart") + if kernel["GlobalSplitU"] > 1: # Only GlobalSplitUSummationAssignmentRoundRobin supported for groOffsetInMacroTile - would need different math here for start: assert(kernel["GlobalSplitUSummationAssignmentRoundRobin"]) @@ -4760,7 +4943,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe, isPap): kStr += inst("s_addc_u32", sgpr("ShadowLimit%s+1"%tc), sgpr("ShadowLimit%s+1"%tc), 0, "extend limit for directToLDS instruction offset") kStr += inst("s_cmp_eq_u32", sgpr("ShadowLimit%s+1"%tc), 0, "are we within 2^32?") - kStr += inst("s_cselect_b32", sgpr("Srd%s+2"%tc), sgpr("ShadowLimit%s+0"%tc), "BufferLimit", "Move shadow to real if we are within 2^32") + kStr += inst("s_cselect_b32", sgpr("Srd%s+2"%tc), sgpr("ShadowLimit%s+0"%tc), "BufferLimit%s"%tc, "Move shadow to real if we are within 2^32") else: # put limit directly into SRD: kStr += inst("s_lshl_b32", sgpr("Srd%s+2"%tc), sgpr(stmp+0), hex(log2(tP["bpe"])), "Set limit to use bytes") @@ -4789,14 +4972,17 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe, isPap): kStr += inst("s_addc_u32", sgpr(tileStart+1), sgpr(tileStart+1), sgpr(stmp+1), "accum wg term to tilestart") wg+=1 + sgprAddress0 = sgpr("%s%s+0"%(self.sgprAddressStrAB,tc)) + sgprAddress1 = sgpr("%s%s+1"%(self.sgprAddressStrAB,tc)) + # Add the tile start to the SRD if wroteTileStart: kStr += scalarStaticMultiply(sgpr(tileStart,2), sgpr(tileStart,2), bpe, None, "tileStart *= BPE") - kStr += inst("s_add_u32", sgpr("Srd%s+0"%tc), sgpr("%s%s+0"%(self.sgprAddressStrAB,tc)), sgpr(tileStart+0), "SRD base = Address+ tileStart0") - kStr += inst("s_addc_u32", sgpr("Srd%s+1"%tc), sgpr("%s%s+1"%(self.sgprAddressStrAB,tc)), sgpr(tileStart+1), "SRD base = Address+ tileStart1") + kStr += inst("s_add_u32", sgpr("Srd%s+0"%tc), sgprAddress0, sgpr(tileStart+0), "SRD base = Address+ tileStart0") + kStr += inst("s_addc_u32", sgpr("Srd%s+1"%tc), sgprAddress1, sgpr(tileStart+1), "SRD base = Address+ tileStart1") else: - kStr += inst("s_mov_b32", sgpr("Srd%s+0"%tc), sgpr("%s%s+0"%(self.sgprAddressStrAB,tc)), "init SRD base address (lower )" ) - kStr += inst("s_mov_b32", sgpr("Srd%s+1"%tc), sgpr("%s%s+1"%(self.sgprAddressStrAB,tc)), "init SRD base address (upper) + other fields" ) + kStr += inst("s_mov_b32", sgpr("Srd%s+0"%tc), sgprAddress0, "init SRD base address (lower )" ) + kStr += inst("s_mov_b32", sgpr("Srd%s+1"%tc), sgprAddress1, "init SRD base address (upper) + other fields" ) # self.groOffsetInMacroTile == 1 case, pre-pad is already subtracted from AddressA/B if prePad 
and self.groOffsetInMacroTile == 0: @@ -4820,8 +5006,8 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe, isPap): # - subtract the SRD base and SRD buffer limit # - Make sure the 64bit result is >0 kStr += inst("s_lshl_b64", sgpr(stmp,2), sgpr("Tensor2dSize%s"%tc,2), log2(bpe), "tensor size in bytes") - kStr += inst("s_add_u32", sgpr(stmp+0), sgpr(stmp+0), sgpr("Address%s+0"%tc), "add start ptr to compute tensor%s bot-right"%tc) - kStr += inst("s_addc_u32", sgpr(stmp+1), sgpr(stmp+1), sgpr("Address%s+1"%tc), "add start ptr to compute tensor%s bot-right"%tc) + kStr += inst("s_add_u32", sgpr(stmp+0), sgpr(stmp+0), sgprAddress0, "add start ptr to compute tensor%s bot-right"%tc) + kStr += inst("s_addc_u32", sgpr(stmp+1), sgpr(stmp+1), sgprAddress1, "add start ptr to compute tensor%s bot-right"%tc) kStr += inst("s_sub_u32", sgpr(stmp+0), sgpr(stmp+0), sgpr("Srd%s+0"%tc), "sub SRD base") kStr += inst("s_subb_u32", sgpr(stmp+1), sgpr(stmp+1), sgpr("Srd%s+1"%tc), "sub SRD base") if self.use64bShadowLimit: @@ -4849,6 +5035,11 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe, isPap): else: kStr += inst("s_mov_b32", sgpr("InitialSrd%sLimit"%tc), sgpr("Srd%s+2"%tc), "save limit") + # invalid global read for performance evaluation only + if self.enable["InvalidGlobalRead%s"%tc]: + kStr += inst("s_mov_b32", sgpr("Srd%s+2"%tc), hex(0), "set out-of-bound addr for performance evaluation only") + kStr += inst("s_mov_b32", sgpr("ShadowLimit%s+1"%tc), hex(0xffffffff), "set out-of-bound addr for performance evaluation only") + return kStr ############################################################################## @@ -5364,6 +5555,11 @@ def lwaFirstOffset(self, kernel, tP, uDu=0): #if tP["isA"]: #kStr += self.dump(vgpr("LocalWriteAddr%s"%tP["tensorChar"])) #kStr += self.bomb(-40) + + # invalid local write for performance evaluation only + if self.enable["InvalidLocalWrite%s"%tc]: + kStr += inst("v_mov_b32", vgpr(destVgpr), self.LdsOOB, "set out-of-bound addr for performance evaluation only") + # do not generate local write address code if DirectToVgpr is enabled return "" if self.dontAppendCode or kernel["DirectToVgpr%s"%tc] else kStr @@ -5489,6 +5685,10 @@ def lraFinalOffset(self, kernel, tP): "Final Offset: add padding %u per block %u" % (kernel["LdsPad%s"%tc], kernel["LdsBlockSizePerPad%s"%tc])) self.vgprPool.checkIn(rReg) + # invalid local read for performance evaluation only + if self.enable["InvalidLocalRead%s"%tc]: + kStr += inst("v_mov_b32", finalVgpr, self.LdsOOB, "set out-of-bound addr for performance evaluation only") + return kStr ############################################################################## @@ -5785,7 +5985,8 @@ def declareLoopNumIter(self, kernel): ############################################################################## def declareStaggerParms(self, kernel): kStr="" - tmpSgpr = self.getTmpSgpr(2).idx() + tmpSgprRef = self.getTmpSgpr(2) + tmpSgpr = tmpSgprRef.idx() if self.staggerU: # this could be dynamic? 
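# --- Editor's sketch (hypothetical helper, not part of this patch): Stream-K stagger policy ---
# Illustrates the rule added further below in declareStaggerParms: a workgroup keeps its
# StaggerUIter only when it owns the whole tile; a partial tile (one that neither starts at
# iteration 0 nor reaches ItersPerTile) forces the stagger to 0 so the staggered start offset
# cannot exceed the reduced workload.
def _sk_effective_stagger(stagger_u_iter, local_start, local_end, iters_per_tile):
    # Mirrors the two s_cmp/s_cmov pairs guarded by kernel["StreamK"] below.
    owns_whole_tile = (local_start == 0) and (local_end == iters_per_tile)
    return stagger_u_iter if owns_whole_tile else 0
# e.g. _sk_effective_stagger(3, 0, 8, 8) == 3, while any partial tile returns 0.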
if kernel["StaggerUMapping"] == 0: @@ -5814,6 +6015,14 @@ def declareStaggerParms(self, kernel): if kernel["_staggerStrideShift"] > 0: # generate code only when it is necessary kStr += inst("s_lshl_b32", sgpr("StaggerUIter"), sgpr("StaggerUIter"), \ kernel["_staggerStrideShift"], "shift by StaggerUStride") + + if kernel["StreamK"]: + # Set stagger=0 for partial tiles to avoid using stagger larger than workload + kStr += inst("s_cmp_gt_u32", sgpr("StreamKLocalStart"), 0, "does wg start tile?") + kStr += inst("s_cmov_b32", sgpr("StaggerUIter"), 0, "set stagger=0 for partial tiles") + kStr += inst("s_cmp_lt_u32", sgpr("StreamKLocalEnd"), sgpr("ItersPerTile"), "does wg finish tile?") + kStr += inst("s_cmov_b32", sgpr("StaggerUIter"), 0, "set stagger=0 for partial tiles") + return kStr ############################################################################## @@ -5826,7 +6035,8 @@ def calculateStagger(self, kernel, tP): if self.staggerU: assert (kernel["BufferLoad"]) - staggerTmp = self.getTmpSgpr(2).idx() + staggerTmpRef = self.getTmpSgpr(2) + staggerTmp = staggerTmpRef.idx() #--- imod.addComment1("SRDs += (StaggerUIter) * GlobalReadIncs%s+%u"% (tc, self.unrollIdx)) @@ -5891,7 +6101,8 @@ def removeStagger(self, kernel, tP): imod = Code.Module("removeStagger") if self.staggerU: tc = tP["tensorChar"] - tmp = self.getTmpSgpr(4).idx() + tmpRef = self.getTmpSgpr(4) + tmp = tmpRef.idx() tmpForInc = tmp tmpForExtra = tmp + 2 # need to use extra 64bit mul to avoid negative value by subtraction @@ -6023,11 +6234,17 @@ def calculateLoopNumIter(self, kernel, loopIdx, isPap): dividend = loopCounter kStr += scalarStaticDivideAndRemainder( loopCounterName, None, dividend, kernel["LocalSplitU"], tmpSgpr, 0) + # skip tail loop if StreamK WG not processing final iteration + if kernel["StreamK"]: + # Check if tile finished + kStr += inst("s_cmp_lt_u32", sgpr("StreamKLocalEnd"), sgpr("ItersPerTile"), "Check if WG processes final iteration of tile") + kStr += inst("s_cmov_b32", loopCounter, hex(0), "This WG not completing tile") + # if GSU numIter=0 if gsuSumIdx != remainder if kernel["GlobalSplitU"] > 1: kStr += inst("s_cmp_lg_u32", sgpr("GSUSumIdx"), sgpr("GSUSumIdx+1"), \ "gsuSumIdx == numIterPerWgRemainder" ) - kStr += inst("s_cmov_b32", loopCounter, hex(0), "numIter=0 if gsuSimIdx!=remainder") + kStr += inst("s_cmov_b32", loopCounter, hex(0), "numIter=0 if gsuSumIdx!=remainder") # do not use early exit here in tailLoop in NLL case if not self.tailLoopInNLL: @@ -6062,7 +6279,19 @@ def calculateLoopNumIter(self, kernel, loopIdx, isPap): divisor = kernel["DepthU"] asem = kernel["AssertSummationElementMultiple"] gsu = kernel["GlobalSplitU"] - if self.noTailLoop and ((asem % gsu != 0) or ((asem//gsu) % kernel["DepthU"] != 0)): + if kernel["StreamK"]: + # Use StreamK params for loop count + kStr += inst("s_sub_u32", sgpr(loopCounterName), sgpr("StreamKLocalEnd"), sgpr("StreamKLocalStart"), "StreamK loop counter = localEnd - localStart") + # Adjust loop count for tail loop + if not self.noTailLoop: + loopChar = self.indexChars[kernel["ProblemType"]["IndicesSummation"][self.unrollIdx]] + kStr += scalarStaticDivideAndRemainder(tmpSgpr, tmpSgpr+1, "SizesSum+%u"%self.unrollIdx, kernel["DepthU"], tmpSgpr+2, 2) + kStr += inst("s_cmp_eq_u32", sgpr(tmpSgpr+1), hex(0), "numIter%s == 0"%loopChar ) + kStr += inst("s_cselect_b32", sgpr(tmpSgpr), 0, 1, "check if size uses tail loop") + kStr += inst("s_cmp_eq_u32", sgpr("StreamKLocalEnd"), sgpr("ItersPerTile"), "Check if WG processes final iteration of tile") + kStr += 
inst("s_cselect_b32", sgpr(tmpSgpr), sgpr(tmpSgpr), 0, "this WG runs tail loop") + kStr += inst("s_sub_u32", sgpr(loopCounterName), sgpr(loopCounterName), sgpr(tmpSgpr), "Adjust loop counter for tail loop") + elif self.noTailLoop and ((asem % gsu != 0) or ((asem//gsu) % kernel["DepthU"] != 0)): # round up SizesSum/DepthU for noTailLoop case kStr += inst("s_add_i32", sgpr(quotient), (divisor - 1), sgpr(dividend), \ "round up SizeSum / DepthU" ) @@ -6492,19 +6721,28 @@ def closeLoop(self, kernel, loopIdx, finalLoop, loopCopies, uDu=None, emitEndLab # in this case, odd or/and even code is generated and use odd/even exit to avoid skipping odd/even code # (end label is generated after odd/even code) jumpLabel = loopLabelEndOddExit if oddLabel else loopLabelEndEvenExit + + # tail + SplitLds branch code + kStrSLDS = "" + if tailLoop and kernel.enabledSplitLDS: + tailLoopLabelEnd = self.getNamedLabel( + "TailLoopEnd%s%s"%(loopChar, "_G2L%s"%(kernel["DepthULdsDivisor"]-1) if kernel.enabledSplitLDS else "") ) + kStrSLDS += inst("s_cbranch_scc1", tailLoopLabelEnd, "break Loop%s"%loopChar) + thresForNextSubLoop = (uDu+1)*(kernel["_DepthULds"]) + kStrSLDS += inst("s_cmp_ge_u32", sgpr("OrigLoopCounter"), thresForNextSubLoop, + "OrigLoopCounter >= %u (G2L buffer %u/%u)"%(thresForNextSubLoop, uDu, kernel["DepthULdsDivisor"]) ) + if not finalLoop: - if jumpNeeded: + if kStrSLDS != "": + # tail + SplitLds branch case + kStr += kStrSLDS + elif jumpNeeded: # just an exit check, else fall through to the next loop copy kStr += inst("s_cbranch_scc1 %s"%(jumpLabel), "exit Loop%s"%loopChar ) else: #finalLoop: - if tailLoop and kernel.enabledSplitLDS: - tailLoopLabelEnd = self.getNamedLabel( - "TailLoopEnd%s%s"%(loopChar, "_G2L%s"%(kernel["DepthULdsDivisor"]-1) if kernel.enabledSplitLDS else "") ) - kStr += inst("s_cbranch_scc1", tailLoopLabelEnd, "break Loop%s"%loopChar) - thresForNextSubLoop = (uDu+1)*(kernel["_DepthULds"]) - kStr += inst("s_cmp_ge_u32", sgpr("OrigLoopCounter"), thresForNextSubLoop, - "OrigLoopCounter >= %u (G2L buffer %u/%u)"%(thresForNextSubLoop, uDu, kernel["DepthULdsDivisor"]) ) + # add tail + SplitLds branch case code (if exists) + kStr += kStrSLDS if jumpNeeded: kStr += inst("%s %s"%(finalJump, loopLabelBegin), \ @@ -6572,7 +6810,7 @@ def closeLoop(self, kernel, loopIdx, finalLoop, loopCopies, uDu=None, emitEndLab kStr += "%s:%s" % (loopLabelEnd, self.endLine) if tailLoop and not self.tailLoopInNLL: - if kernel["PersistentKernel"] or len(kernel["ProblemType"]["IndicesSummation"]) > 1: + if kernel["PersistentKernel"] or kernel["StreamK"] or len(kernel["ProblemType"]["IndicesSummation"]) > 1: # recover the 'damage' done to LRO: stmp = self.getTmpSgpr(1).idx() @@ -6652,6 +6890,7 @@ def endSummation(self, kernel, label = None, isOptNLL = False): kStr += self.comment1("endSummation: add vgpr [%u...%u) to pool" % \ (vbegin, vbegin+vsize)) + # TODO this undef limits ability to define sgprs closer to the time it is used lastRegTag=None for i in range(self.lastPostLoopSgpr, self.sgprPool.size()): regTag = self.sgprPool.pool[i].tag @@ -6700,9 +6939,8 @@ def endSummation(self, kernel, label = None, isOptNLL = False): ############################################################################## # src A,B str for MFMA ############################################################################## - def generateSrcStrForMFMA(self, kernel, tP, innerUnroll, vregSetIdx, vgprPerInput, u, iui, idxAB, bk=None): + def generateSrcStrForMFMA(self, kernel, tP, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, 
idxAB, bk=None): tc = tP["tensorChar"] - m = (u) % (self.numVgprBuffer+1) # local to use for MACs numVgprValuPerBlock = kernel["MIWaveTile%c"%tc] * kernel["MIInputPerThread"] * tP["bpe"] // self.bpr numIterPerCoalescedRead = self.numIterPerCoalescedReadA if tP["isA"] else self.numIterPerCoalescedReadB @@ -6734,7 +6972,7 @@ def generateSrcStrForMFMA(self, kernel, tP, innerUnroll, vregSetIdx, vgprPerInpu def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=False, firstIter=False): imod = Code.Module("mi") shiftK = Code.Module("shiftK") - m = (u) % (self.numVgprBuffer+1) # local to use for MACs + m = ((u) % (self.numVgprBuffer+1)) % kernel["LoopIters"] # local to use for MACs miInputType = kernel["ProblemType"]["F32XdlMathOp"] if kernel["EnableF32XdlMathOp"] else kernel["ProblemType"]["DataType"] @@ -6817,13 +7055,13 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F if needKMaskForA: for a in range(0, kernel["MIWaveTileA"]): for iui in range(0, innerUnroll): - aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, u, iui, a, bk) + aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, a, bk) aStr = vgpr(aStr_base, 1) shiftK.addCode(inst("v_cndmask_b32", aStr, aStr, hex(0), sgpr(tmpSgpr, 2), "set 0 if K_idx >= sizeL")) if needKMaskForB: for b in range(0, kernel["MIWaveTileB"]): for iui in range(0, innerUnroll): - bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, u, iui, b, bk) + bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, b, bk) bStr = vgpr(bStr_base, 1) shiftK.addCode(inst("v_cndmask_b32", bStr, bStr, hex(0), sgpr(tmpSgpr, 2), "set 0 if K_idx >= sizeL")) @@ -6864,14 +7102,14 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F for a in range(0, kernel["MIWaveTileA"]): for iui in range(0, innerUnroll): for bk in range(0, vgprPerInput): - aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, u, iui, a, bk) + aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, a, bk) aStr = vgpr(aStr_base, 1) shiftK.addCode(inst("v_and_b32", aStr, aStr, vgpr(abReg+bk), "")) if needKMaskForB: for b in range(0, kernel["MIWaveTileB"]): for iui in range(0, innerUnroll): for bk in range(0, vgprPerInput): - bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, u, iui, b, bk) + bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, b, bk) bStr = vgpr(bStr_base, 1) shiftK.addCode(inst("v_and_b32", bStr, bStr, vgpr(abReg+bk), "")) # release register @@ -6956,8 +7194,8 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F accEnd = accStart + accs_per_wave - 1 idxA = idx0 if self.tPB["tile01Idx"] else idx1 idxB = idx1 if self.tPB["tile01Idx"] else idx0 - aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, u, iui, idxA) - bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, u, iui, idxB) + aStr_base = self.generateSrcStrForMFMA(kernel, self.tPA, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, idxA) + bStr_base = self.generateSrcStrForMFMA(kernel, self.tPB, innerUnroll, vregSetIdx, vgprPerInput, m, u, iui, idxB) aStr = 
vgpr(aStr_base, vgprPerInput) bStr = vgpr(bStr_base, vgprPerInput) Str0 = aStr if self.tPB["tile01Idx"] else bStr @@ -7246,7 +7484,7 @@ def openSumAtLeastUnroll(self, kernel, prefetch, isOptNLL, isPap): # skip beta check for StoreCInUnroll in OptNLL case if not kernel["StoreCInUnroll"]: - kStr += self.checkIsBetaZero(kernel, tmpSgpr, skipOptNLL) + kStr += self.checkIsBetaZero(kernel, skipOptNLL, tmpSgpr) # check alpha # skip alpha check for StoreCInUnroll in OptNLL case @@ -7255,7 +7493,7 @@ def openSumAtLeastUnroll(self, kernel, prefetch, isOptNLL, isPap): if kernel["ProblemType"]["ComputeDataType"].isHalf(): if kernel["ProblemType"]["HighPrecisionAccumulate"] and \ - kernel["PersistentKernel"]: + (kernel["PersistentKernel"] or kernel["StreamK"]): kStr += inst("s_cmp_eq_u32", sgpr("Alpha"), "1.0", "Alpha == 1.0 ?") # Otherwise, Alpha is a packed F16 so far (if Non-PK, the cvt is done later in GW) else: @@ -7299,7 +7537,7 @@ def openSumAtLeastUnroll(self, kernel, prefetch, isOptNLL, isPap): kStr += inst("s_cbranch_scc0 %s"%skipOptNLL, "branch if alpha != 1") kStr += "\n" - kStr += self.checkIsEdge(kernel, tmpSgpr, skipOptNLL) + kStr += self.checkIsEdge(kernel, skipOptNLL, tmpSgpr) kStr += "\n" # Check tail loop required: @@ -7445,7 +7683,8 @@ def closeSumAtLeastUnroll(self, kernel, prefetch, isOptNLL, isPap, isNGLL): oldSize = self.savedSgprPool.size() newSize = self.sgprPool.size() if newSize > self.savedSgprPool.size(): - for i in range(oldSize-1,newSize): + # fixed range to prevent overflowing resources in some cases + for i in range(oldSize,newSize): self.savedSgprPool.pool.append(self.savedSgprPool.Register(RegisterPool.Status.Available,"restore sgprPool")) self.sgprPool = self.savedSgprPool # restore vgprPool before alternate path self.savedSgprPool = None @@ -7485,7 +7724,7 @@ def incrementSrd(self, kernel, tP, incLower, incUpper, checkShadowLimitCopy=True imod.addInst("s_cmp_eq_u32", sgpr("ShadowLimit%s+1"%tc), 0, "are we within 2^32?") if self.staggerU: # staggerU case, need to restore BufferLimit when ShadowLimit goes to negative value - imod.addInst("s_cselect_b32", sgpr("Srd%s+2"%tc), sgpr("ShadowLimit%s+0"%tc), "BufferLimit", "Move shadow to real if we are within 2^32") + imod.addInst("s_cselect_b32", sgpr("Srd%s+2"%tc), sgpr("ShadowLimit%s+0"%tc), "BufferLimit%s"%tc, "Move shadow to real if we are within 2^32") else: imod.addInst("s_cmov_b32", sgpr("Srd%s+2"%tc), sgpr("ShadowLimit%s+0"%tc), "Move shadow to real if we are within 2^32") else: @@ -8520,6 +8759,24 @@ def globalReadDo(self, kernel, mode, tP, vregSetIdx=0): return imod + + ############################################################################## + # Global Read A/B completed + ############################################################################## + def doneGlobalABReads(self, kernel): + kStr = "" + kStr += self.comment("Done global A/B reads") + # TODO Many kernels can undefine this, but condition needs to be updated + # Has a problem with tail loop in convolution kernels which have PK elements, but not marked PK + # if kernel["BufferLoad"] and not kernel["PrefetchAcrossPersistent"]: + # kStr += self.undefineSgpr("SrdA") + # kStr += self.undefineSgpr("SrdB") + # TODO Should be able to define WS here but it gets undef'd by endSummation + # if kernel["StreamK"] >= 2: + # self.defineSgpr("SrdWS", 4, 4) + return kStr + + ############################################################################## # Local Write: Swap Offsets A/B 
############################################################################## @@ -8718,7 +8975,7 @@ def recalcLocalWriteAddresses(self, kernel, tP, uDu): [self.localWriteStrideTileA, self.localWriteStrideUnrollA] ) tP["localWriteInstruction"] = self.memoryInstructions["LocalWrite"][newInstIdx] - if kernel["PersistentKernel"]: + if kernel["PersistentKernel"] or kernel["StreamK"]: if getattr(self, "oriLwa%s"%tc) is None: setattr(self, "oriLwa%s"%tc, self.vgprPool.checkOut(1, "OriLocalWriteddr%s"%tc) ) kStr += inst("v_mov_b32", vgpr(getattr(self, "oriLwa%s"%tc)), vgpr("LocalWriteAddr%s"%tc), "back up LWA for persistent kernel + wider local read") @@ -8761,15 +9018,17 @@ def recalcLocalReadAddressesAB(self, kernel): # however, we need to update related variables below and regenerate local read instruction based on new numReadsIterCoalesced numReadsIterCoalescedA = self.numReadsIterCoalescedA numReadsIterCoalescedB = self.numReadsIterCoalescedB - if kernel.enabledSplitLDS or (numReadsIterCoalescedA > 1 or numReadsIterCoalescedB > 1): #and tP["isB"]: - self.numReadsIterCoalescedA = 1 - self.numReadsIterCoalescedB = 1 - self.lrvwA = kernel["MIInputPerThread"] - self.lrvwB = kernel["MIInputPerThread"] - kStr = "" + kStr = "" + + needRecalc = kernel.enabledSplitLDS or (numReadsIterCoalescedA > 1 or numReadsIterCoalescedB > 1) #and tP["isB"]: + # backup LocalReadAddr + # LdsPad + LBSPP case, need to backup LocalReadAddr even if recalc is not done + needBackupLRAddr = needRecalc or (kernel["LdsPadA"] and kernel["LdsBlockSizePerPadA"] or kernel["LdsPadB"] and kernel["LdsBlockSizePerPadB"]) + + if needBackupLRAddr: # need to back-up the LRA before reCalculation for wider local read (when no wlr, no need to do this) - if kernel["PersistentKernel"]: + if kernel["PersistentKernel"] or kernel["StreamK"]: if self.oriLraA is None and not kernel["DirectToVgprA"]: # no local read code if DirectToVgpr is enabled self.oriLraA = self.vgprPool.checkOut(1, "OriLocalReadAddrA") kStr += inst("v_mov_b32", vgpr(self.oriLraA), vgpr("LocalReadAddrA"), "back up LRA for persistent kernel + wider local read") @@ -8777,14 +9036,23 @@ def recalcLocalReadAddressesAB(self, kernel): self.oriLraB = self.vgprPool.checkOut(1, "OriLocalReadAddrB") kStr += inst("v_mov_b32", vgpr(self.oriLraB), vgpr("LocalReadAddrB"), "back up LRA for persistent kernel + wider local read") - kStr += (self.lraTileAssignment(kernel, self.tPA, self.tPB)) - kStr += (self.lraFinalOffset(kernel, self.tPA)) - kStr += (self.lraDeclareAddresses(kernel, self.tPA)) - kStr += (self.lraFinalOffset(kernel, self.tPB)) - kStr += (self.lraDeclareAddresses(kernel, self.tPB)) + if needRecalc: + self.numReadsIterCoalescedA = 1 + self.numReadsIterCoalescedB = 1 + self.numIterPerCoalescedReadA = max(1,self.numReadsIterCoalescedA//kernel["InnerUnroll"]) + self.numIterPerCoalescedReadB = max(1,self.numReadsIterCoalescedB//kernel["InnerUnroll"]) + self.lrvwA = kernel["MIInputPerThread"] + self.lrvwB = kernel["MIInputPerThread"] + + kStrRecalc = "" + kStrRecalc += (self.lraTileAssignment(kernel, self.tPA, self.tPB)) + kStrRecalc += (self.lraFinalOffset(kernel, self.tPA)) + kStrRecalc += (self.lraDeclareAddresses(kernel, self.tPA)) + kStrRecalc += (self.lraFinalOffset(kernel, self.tPB)) + kStrRecalc += (self.lraDeclareAddresses(kernel, self.tPB)) if kernel["MatrixInstB"] == 1: # recalc code is necessary only for MatrixInstB=1 - imod.addCode(kStr) + kStr += kStrRecalc localRead2Perpendicular = False instructions = self.memoryInstructions @@ -8808,6 +9076,10 @@ def 
recalcLocalReadAddressesAB(self, kernel): self.localReadInstructionB = instructions["LocalRead"][ \ self.localReadInstructionIdxB] self.tPB["localReadInstruction"] = self.localReadInstructionB + + if kStr != "": + imod.addCode(kStr) + return str(imod) ############################################################################## @@ -9302,16 +9574,10 @@ def localReadOffsetConvForDTL(self, kernel, tP, offset_val): return offset_val ############################################################################## - # Local Read: Increment A/B + # Local Read: Increment A/B sub function ############################################################################## - def localReadInc(self, kernel, iui, tP): + def localReadIncSub(self, kernel, iui, tP, LdsPad): tc = tP["tensorChar"] - if not self.do["LocalRead%s" % tc] or kernel["DirectToVgpr%s"%tc]: # no local read code if DirectToVgpr is enabled - return "" - - kStr = "" - - LdsPad = kernel["LdsPad%s"%tc] if kernel["LdsBlockSizePerPad%s"%tc] == 0 else 0 # offset increment calculation for both tail loop and not tail loop cases inc_base = (kernel["MacroTile%s" % tP["tensorChar"]] + LdsPad) @@ -9343,6 +9609,23 @@ def localReadInc(self, kernel, iui, tP): else: inc = inc_base_lsu + return inc + + ############################################################################## + # Local Read: Increment A/B + ############################################################################## + def localReadInc(self, kernel, iui, tP): + tc = tP["tensorChar"] + if not self.do["LocalRead%s" % tc] or kernel["DirectToVgpr%s"%tc]: # no local read code if DirectToVgpr is enabled + return "" + + kStr = "" + + LdsPad = kernel["LdsPad%s"%tc] if kernel["LdsBlockSizePerPad%s"%tc] == 0 else 0 + + # offset increment calculation for both tail loop and not tail loop cases + inc = self.localReadIncSub(kernel, iui, tP, LdsPad) + if self.inTailLoop: comment = " (LSU*(MT+PAD)*bpe)" bpe = tP["bpe"] @@ -9363,6 +9646,15 @@ def localReadInc(self, kernel, iui, tP): else: comment = " (LSU*bpe)" inc *= bpe + # adjust inc for LBSPP in TailLoop case + if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): + incTotalPrev = 0 + # acculmulate total inc from 0 to iui-1 + for i in range(iui): + incTotalPrev += self.localReadIncSub(kernel, i, tP, LdsPad) * bpe + extraIncPrev = ((incTotalPrev) // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * bpe + extraIncCurr = ((inc+incTotalPrev) // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * bpe + inc += extraIncCurr - extraIncPrev tmpSgpr = self.getTmpSgpr(1).idx() kStr += inst("s_mov_b32", sgpr(tmpSgpr), hex(inc), "inc") kStr += inst("_v_add_co_u32", \ @@ -9902,7 +10194,8 @@ def localSplitUReduction(self, kernel): def computeStoreSrdStart(self, kernel): kStr = "" - tmpS0 = self.getTmpSgpr(3).idx() + tmpS0ref = self.getTmpSgpr(3) + tmpS0 = tmpS0ref.idx() tmpS1 = tmpS0+1 wgMT1 = tmpS0+2 @@ -10120,14 +10413,16 @@ def localSplitUGlobalWriteIndices(self, kernel): return kStr ############################################################################## - def allocPostLoopSrd(self, kernel, ch): + def allocPostLoopSrd(self, kernel, tc): kStr = "" + sgprAddress0 = sgpr("Address%s+0"%tc) + sgprAddress1 = sgpr("Address%s+1"%tc) # Buffer-load uses one base read pointer stored in the SRD - set it here: if not self.releaseSgprAdressCD: - kStr += inst("s_mov_b32", sgpr("Srd%s+0"%ch), sgpr("Address%s+0"%ch), "init SRD base address (lower)" ) - kStr += inst("s_mov_b32", sgpr("Srd%s+1"%ch), sgpr("Address%s+1"%ch), "init 
SRD base address (upper) + other fields" ) - kStr += inst("s_mov_b32", sgpr("Srd%s+2"%ch), "BufferOOB", "") - kStr += inst("s_mov_b32", sgpr("Srd%s+3"%ch), "Srd127_96", "Set bits 127_96 in post-loop SRD") + kStr += inst("s_mov_b32", sgpr("Srd%s+0"%tc), sgprAddress0, "init SRD base address (lower)" ) + kStr += inst("s_mov_b32", sgpr("Srd%s+1"%tc), sgprAddress1, "init SRD base address (upper) + other fields" ) + kStr += inst("s_mov_b32", sgpr("Srd%s+2"%tc), "BufferOOB", "") + kStr += inst("s_mov_b32", sgpr("Srd%s+3"%tc), "Srd127_96", "Set bits 127_96 in post-loop SRD") kStr += "\n" return kStr @@ -10611,7 +10906,7 @@ class StoreState: # the generation of the store code. ############################################################################## class StoreConstConfig: - def __init__(self, kernelWriter, kernel, ss, gwvw, edge, beta, atomic): + def __init__(self, kernelWriter, kernel, ss, gwvw, edge, beta, atomic, isWorkspace=False): self.gwvw = gwvw if ss.optSingleColVgpr: @@ -10661,14 +10956,18 @@ def __init__(self, kernelWriter, kernel, ss, gwvw, edge, beta, atomic): and kernel["ProblemType"]["DestDataType"].isHalf() \ and (not kernel["ProblemType"]["HighPrecisionAccumulate"]) + bpeC = kernelWriter.bpeCexternal + if isWorkspace: + bpeC = kernelWriter.bpeCinternal + if atomic: # flat atomics have another VGPR to allow different data for return# regsPerElement = 2 # The atomic loop processes multiple elements in single instruction # so will use VGPR from consec elements? TODO - self.numVgprsPerDataPerVI = (1.0 * regsPerElement * kernelWriter.bpeCexternal) / kernelWriter.bpr + self.numVgprsPerDataPerVI = (1.0 * regsPerElement * bpeC) / kernelWriter.bpr elif beta: - self.numVgprsPerDataPerVI = (1.0 * kernelWriter.bpeCexternal) / kernelWriter.bpr + self.numVgprsPerDataPerVI = (1.0 * bpeC) / kernelWriter.bpr if kernelWriter.HHH_WMMA: self.numVgprsPerDataPerVI = 1.0 else: @@ -10676,6 +10975,7 @@ def __init__(self, kernelWriter, kernel, ss, gwvw, edge, beta, atomic): if kernelWriter.serializedStore: #self.numVgprPerValuC = kernel["MIRegPerOut"] + # TODO check stream-k case self.numVgprPerValuC = kernelWriter.bpeCinternal//kernelWriter.bpr # vgpr needed from register pool else: self.numVgprPerValuC = 0 # null since they are already declared in macro part of assembly kernel @@ -10689,7 +10989,7 @@ def __init__(self, kernelWriter, kernel, ss, gwvw, edge, beta, atomic): self.halfDataRegPerVI = gwvw*self.numVgprsPerDataPerVI < 1.0 and not (kernel["ProblemType"]["UseInitialStridesCD"] and kernelWriter.archCaps["HasEccHalf"]) # StoreState constructor: - def __init__(self, kernelWriter, kernel, gwvw, edge, beta, atomic, elements): + def __init__(self, kernelWriter, kernel, gwvw, edge, beta, atomic, elements, isWorkspace=False): self.kernelWriter = kernelWriter self.kernel = kernel @@ -10740,7 +11040,7 @@ def __init__(self, kernelWriter, kernel, gwvw, edge, beta, atomic, elements): if not atomic and len(kernel["PackedC1IndicesX"]) == 1: self.optSrdIncForRow = 1 - if kernel["StoreRemapVectorWidth"]: + if kernel["StoreRemapVectorWidth"] and not isWorkspace: self.optSrdIncForRow = 1 if kernel["ProblemType"]["UseInitialStridesCD"]: @@ -10755,7 +11055,7 @@ def __init__(self, kernelWriter, kernel, gwvw, edge, beta, atomic, elements): assert (not (self.optSingleColVgpr and self.optSharedColVgpr)) - self.cfg = self.StoreConstConfig(kernelWriter, kernel, self, gwvw, edge, beta, atomic) + self.cfg = self.StoreConstConfig(kernelWriter, kernel, self, gwvw, edge, beta, atomic, isWorkspace) # Use to detect new rows: 
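# --- Editor's sketch (hypothetical helper, not in the patch) ---
# Illustrates the width selection made just above in StoreConstConfig: Stream-K partials are
# kept in the workspace at the compute precision (bpeCinternal), while ordinary stores size
# their data registers from the destination element width (bpeCexternal); the per-element
# VGPR count then follows from this byte width.
def _store_element_bytes(bpe_dest, bpe_compute, is_workspace):
    # e.g. HGEMM with fp32 accumulation: bpe_dest=2, bpe_compute=4
    return bpe_compute if is_workspace else bpe_dest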
self.lastCoordOffset1 = 0 @@ -10803,7 +11103,7 @@ def __init__(self, kernelWriter, kernel, gwvw, edge, beta, atomic, elements): # # Also create an AddrCalc for each memory operation. ############################################################################## - def setupStoreElementsForBatch(self, kernel, gwvw, batchElements, batchElementSgprs, preventOverflow, VectorWidthB): + def setupStoreElementsForBatch(self, kernel, gwvw, batchElements, batchElementSgprs, preventOverflow, VectorWidthB, isWorkspace=False): self.elementAddr = [] self.elementData = [] # VGPR to use for element data, needed for atomic or beta @@ -10811,6 +11111,9 @@ def setupStoreElementsForBatch(self, kernel, gwvw, batchElements, batchElementSg self.elementSumIdx = [] kw = self.kernelWriter + dataType = kernel["ProblemType"]["DestDataType"] + if isWorkspace: + dataType = kernel["ProblemType"]["ComputeDataType"] if kernel["EnableMatrixInstructionStore"]: matrixInstM = (kernel["MatrixInstM"] * kernel["MatrixInstBM"]) if (kernel["MatrixInstM"] == 4) else kernel["MatrixInstM"] @@ -10918,10 +11221,10 @@ def setupStoreElementsForBatch(self, kernel, gwvw, batchElements, batchElementSg # TODO- check (H,H,H,H,S,S) # NOTE: Changed from DataType to DestDataType if kernel["ProblemType"]["HighPrecisionAccumulate"] and \ - (kernel["ProblemType"]["DestDataType"].isBFloat16() or kernel["ProblemType"]["DestDataType"].isHalf()): + (dataType.isBFloat16() or dataType.isHalf()): data = kw.vgprPool.checkOutAligned(int(2*self.cfg.numVgprsPerDataPerVI*self.cfg.gwvw), \ int(ceil(int(2*self.cfg.numVgprsPerDataPerVI*self.cfg.gwvw))), "writeBatch-data for ei=%u and ei=%u"%(elementIdx,elementIdx+1), preventOverflow=preventOverflow) - elif kernel["ProblemType"]["DestDataType"].is8bitFloat(): + elif dataType.is8bitFloat(): numRegForData = int(ceil(self.cfg.numVgprsPerDataPerVI*self.cfg.gwvw)) data = kw.vgprPool.checkOutAligned(numRegForData, numRegForData, "writeBatch-data for ei=%u and ei=%u"%(elementIdx,elementIdx+1), preventOverflow=preventOverflow) else: @@ -11153,6 +11456,8 @@ def emitScaleToBpe(self, kernel, ss, tmpVgpr, singleUpdate, tc): """ kStr = "" + if tc == 'WS': # StreamK workspace uses compute size + tc = 'C' kw = self.kernelWriter (d1,d0,vc1,vc0) = self.element rowPtr = kw.cinRowPtr if (tc == 'C') else kw.coutRowPtr @@ -11427,6 +11732,8 @@ def incrementToNextRow(self, kernel, tc, ss, stmp): packedC1 = kernel["PackedC1IndicesX"] assert(len(packedC1) == 1) # would need to extract each dim and scale strideCD1 = "Stride%s%s"%(tc,self.kernelWriter.indexChars[packedC1[0]]) + if tc == 'WS': # TODO StreamK revised large WS + strideCD1 = "Stride%s%s"%('D',self.kernelWriter.indexChars[packedC1[0]]) if numRows > 1: kStr += inst("s_mul_i32", sgpr(stmp), \ sgpr(strideCD1), \ @@ -11459,8 +11766,12 @@ def incrementToNextRow(self, kernel, tc, ss, stmp): # tmpSgpr is one temp sgpr # betaLabel is label to branch to if beta != 0 ############################################################################## - def checkIsBetaZero(self, kernel, tmpSgpr, betaLabel): + def checkIsBetaZero(self, kernel, betaLabel, tmpSgpr=None): kStr = "" + tmpSgprRef = None + if tmpSgpr == None: + tmpSgprRef = self.getTmpSgpr(1) + tmpSgpr = tmpSgprRef.idx() if kernel["ProblemType"]["UseBeta"]: if self.bpeCinternal <= self.bpr: # 1 register to check for Beta==0 kStr += inst("s_cmpk_eq_u32", sgpr("Beta"), hex(0), "Beta == 0") @@ -11479,8 +11790,12 @@ def checkIsBetaZero(self, kernel, tmpSgpr, betaLabel): # tmpSgpr must have at least 6 free SGPR # isEdgeTarget is the branch 
target if edges are required ############################################################################## - def checkIsEdge(self, kernel, tmpSgpr, isEdgeTarget): + def checkIsEdge(self, kernel, isEdgeTarget, tmpSgpr=None): kStr = "" + tmpSgprRef = None + if tmpSgpr == None: + tmpSgprRef = self.getTmpSgpr(3) + tmpSgpr = tmpSgprRef.idx() tmpS0 = tmpSgpr tmpS1 = tmpS0 + 1 tmpS23 = tmpS1 + 1 @@ -11539,368 +11854,1620 @@ def checkIsEdge(self, kernel, tmpSgpr, isEdgeTarget): return kStr ############################################################################## - # Global Write Elements + # Global Write Batch ############################################################################## - def globalWriteElements(self, kernel, vectorWidths, elements, - applyAlpha=True, # defaults to generating *=alpha codes - betas=None, # if left unspecified, then let global parameter decide - edges=None, - isOptNLL=False): # if OptNLL or not (for StoreCInUnroll) - if not kernel["StoreCInUnroll"]: - if not self.do["PostLoop"]: return "" + def fixupBatch(self, kernel, ss, batchIdx, edge, gwvw, \ + batchElements, coord0, coord1, addrD, addrC, \ + tmpVgpr, tmpCVTVgpr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeAccVgprWrite): kStr = "" - atomic = (kernel["GlobalSplitU"] > 1) and (kernel["_GlobalAccumulation"] != 'MultipleBuffer') - useCodeMulAlpha = kernel["MIArchVgpr"] and applyAlpha and not (kernel["GlobalSplitU"] > 1) - - # write possibilities and labels - # if beta/edge combo not specified fall back to global param definition - if betas is None: - hasBeta = kernel["ProblemType"]["UseBeta"] and (kernel["GlobalSplitU"] == 1) - betas = [False, True] if hasBeta else [False] - if edges is None: - edges = [False, True] if self.do["EdgeWrite"] else [False] - writeLabels = {} - for beta in betas: - writeLabels[beta] = {} - for edge in edges: - writeLabels[beta]["EdgeCheck0"] = self.getNamedLabelUnique("GW_B%u_E%u_EdgeCheck0" % ( 1 if beta else 0, 1 if edge else 0) ) - writeLabels[beta]["EdgeCheck1"] = self.getNamedLabelUnique("GW_B%u_E%u_EdgeCheck1" % ( 1 if beta else 0, 1 if edge else 0) ) - writeLabels[beta][edge] = self.getNamedLabelUnique("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ) - if not beta: - betaLabel = self.getNamedLabelUnique("GW_Beta") - endLabel = self.getNamedLabelUnique("GW_End") - - # Layout - """ - if B1 goto label_B1 - if E1 goto label_B0_E1 - label_B0_E0: - writes - goto label_End - label_B0_E1: - writes - goto label_End - label_B1: - if E1 goto label_B1_E1 - label_B1_E0: - writes - goto label_End - label_B1_E1: - writes - goto label_End - label_End - """ - self.betaVgpr = None - - ######################################## - # Vgprs - if kernel["BufferStore"]: - numTmpVgpr = 2 - if len(kernel["PackedC0IndicesX"]) > 1: - numTmpVgpr += 1 - else: - numTmpVgpr = 2 + 3 # GLOBAL_OFFSET_C needs 3, plus 2 tmps? 
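# --- Editor's sketch (hedged and simplified; names below are illustrative, not from the patch) ---
# The fixupBatch code introduced in this hunk is the second phase of a two-phase Stream-K store:
# workgroups that compute only part of a tile write their partial accumulators to a workspace,
# and the workgroup that owns the tile's final iterations reads those partials back and adds
# them into its own accumulators before the normal global write. In scalar Python terms:
def _fixup_tile(acc, workspace, contributing_wgs):
    # acc: this WG's accumulators for the tile; workspace[wg]: partials stored earlier by wg
    for wg in contributing_wgs:
        partial = workspace[wg]
        for i, p in enumerate(partial):
            acc[i] += p
    return acc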
- if useCodeMulAlpha and kernel["ProblemType"]["DataType"].isComplex(): - # codeMulAlpha and complex caes, use tmpVgpr for alpha calculation - numTmpVgpr = max(numTmpVgpr, kernel["ProblemType"]["DataType"].numRegisters()) - tmpVgpr = self.vgprPool.checkOutAligned(numTmpVgpr, 2, "store tmps") - - isHpaBF16 = kernel["ProblemType"]["DestDataType"].isBFloat16() and kernel["ProblemType"]["HighPrecisionAccumulate"] - isHpaF8 = kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8() # F8 is always HPA - - # need temp vgpr both for bf16 and f8 - if isHpaF8: - rcnt = 4 - if kernel["ProblemType"]["Fp32toFp8SWClip"]: - rcnt += 1 - if kernel["ProblemType"]["StochasticRounding"]: - rcnt += 2 - tmpCVTVgpr = self.vgprPool.checkOut(rcnt) - elif isHpaBF16: - tmpCVTVgpr = self.vgprPool.checkOut(4) - else: - tmpCVTVgpr = None - - ######################################## - # Sgprs - # allocate tmps for the store header (before the batch implementations) - tmpSgpr = self.getTmpSgpr(4).idx() + kStr += self.comment1("optSingleColVgpr=%u optSharedColVgpr=%u optSGPRUsage=%s optSrdIncForRow=%u" % \ + (ss.optSingleColVgpr, ss.optSharedColVgpr, ss.optSGPRUsage, ss.optSrdIncForRow)) - # branch B1 or B0 - betaLabel = self.getNamedLabelUnique("GW_Beta") + if kernel["StoreSyncOpt"]: + kStr += "s_sleep %d // optimization: sync and wait\n" %(kernel["StoreSyncOpt"]-1) + kStr += "s_barrier\n" - if False in betas and True in betas: - kStr += self.checkIsBetaZero(kernel, tmpSgpr, betaLabel) + # comment tt1, tt0, vc1, vc0 + # tt = thread tile, vc=vector component + commentStr = "Fixup%s Batch #%u (d1,d0,vc1,vc0) =\n " \ + % (" Edge" if edge else "", batchIdx) + for elementIdx in range(0, len(batchElements)): + element = batchElements[elementIdx] + commentStr += "(%u,%u,%u,%u:vw%u)" % \ + (element[0], element[1], element[2], element[3], gwvw) + if elementIdx < len(batchElements)-1: + commentStr += "; " + kStr += self.comment3(commentStr) + # print(self.kernelName) + # print(commentStr) - for beta in betas: - # start B1 - if beta: - kStr += "%s:\n"%(betaLabel) + # allow expanding vgpr pool for OptNLL + preventOverflow = True #(not isOptNLL) + ss.setupStoreElementsForBatch(kernel, gwvw, batchElements, batchElementSgprs, preventOverflow=preventOverflow, \ + VectorWidthB=self.VectorWidthB, isWorkspace=True) - ######################################## - # branch if Edge0 or Edge1 - if False in edges and True in edges: - kStr += self.checkIsEdge(kernel, tmpSgpr, "%s" % writeLabels[beta][True]) + loadsIssued = 0 + storesIssued = 0 + tmpS01 = tmpSgpr # scratch sgprs - # by now we either jumped to E1 or stayed at E0 - for edge in edges: - kStr += "%s:%s"%(writeLabels[beta][edge], self.endLine) + wavelen = self.kernel["WavefrontSize"] + laneSGPRC = self.laneSGPRCount + # always use gwvw for buffer load C for atomic_cmpswap + # bpm = self.bpeCexternal * atomicW + # bpm = self.bpeCexternal * gwvw + # vgprLoadDW = 1*(bpm//4) + # atomic oparation width. 1 for b32, 2 for b64 + # atomicOpW = (atomicW * self.bpeCexternal) // 4 + # if atomicOpW > 2: + # # should not exceeding 2. 
+ # atomicOpW = 2 - PreLoopVmcntCaseStr = "" - # not generate Case 2 if StoreCInUnroll with StoreVectorWidth==1 (Case 2 will be same as Case 3) - if self.canOptimizePreLoopLWVmcnt: - if beta: - self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_B1_Store - elif edge or (kernel["StoreCInUnroll"] and kernel["StoreVectorWidth"]==1): - self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_E1_Store - else: - self.currPreLoopVmcntCase = PreLoopVmcntCase.OptNLL_Store - PreLoopVmcntCaseStr = inst("s_mov_b32", sgpr("PreLoopLWVmcntCase"), hex(self.currPreLoopVmcntCase.value), \ - "for optimizing next PreLoop LW vmcnt, set to Case%u"%self.currPreLoopVmcntCase.value) - # reset vmcnt if the dict has this key (OptNLL_Store, OrdNLL_E1_Store), - # OrdNLL_B1_Store is excluded - if self.currPreLoopVmcntCase in self.preLoopVmcntDict: - self.preLoopVmcntDict[self.currPreLoopVmcntCase] = 0 - - # for storeRemap edge case, non-beta still can enable vector stores - if kernel["StoreRemapVectorWidth"] and not beta: - edgeI = False - else: - edgeI = edge - #edgeI = True # set to True to disable vector stores - gwvw = vectorWidths[edgeI] - #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic + ######################################## + # calculate addr and masks + kStr += self.comment("calc coords, apply mask, and issue loads (if necessary)") + # On input, coord0 and coord1 are VGPRs computed in the pre-batch code, based + # on the thread and tid number. These are ELEMENT offsets from start of tensor C + # for the top-left corner this thread will write. These are not changed + # across all the store loop iters. + if self.db["ConservativeWaitCnt"] & 0x10: + kStr += "s_barrier // debug\n" + kStr += inst("s_waitcnt", "vmcnt(0)", "ConservativeWaitCnt" ) + if self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + kStr += "s_barrier // debug\n" + if not edge and self.db["ForceEdgeStores"]>=2: + kStr += self.bomb() # should not get here + if edge and self.db["AssertNoEdge"]: + kStr += self.bomb() # should not get here - ######################################## - # Calculate Vgprs for Write Batching - ######################################## + atomicAddC = kernel["AtomicAddC"] and not edge - self.ss = self.StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI]) + ## create code Module to push mov vgpr,acc instructions + if kernel["StoreCInUnroll"] and not edge: + accVgprRead = Code.Module("movaccVgpr") + self.StoreCUnrollLoadCWaitComment = "waitcnt for LoadC" # this will be used later to identify waitcnt for loadC - # how many vgprs are needed for zero elements - # 2 for addressC in vgpr for addition - already checked out - # 2 for coord0,1 of thread - already checked out - # 2 for tmp - already checked out + for elementIdx in range(0, len(batchElements)): + element = batchElements[elementIdx] + addrCVgpr = ss.elementAddr[elementIdx].addrCVgpr + addrDVgpr = ss.elementAddr[elementIdx].addrDVgpr + addrCalc = ss.elementAddr[elementIdx] + data = ss.elementData[elementIdx] + mask = ss.elementMask[elementIdx] + sumIdx = ss.elementSumIdx[elementIdx] + # d1 = element[0] + # d0 = element[1] + # vc1 = element[2] + vc0 = element[3] - # 5 = how many vgprs are needed per element (flat) - # - 2 for addr - # - 3 for GLOBAL_OFFSET_C calculation (can overlap below, therefore max) - # - if beta gwvw*rpe for new value - # - if atomic 2*rpe for old and cmp values + storeWidth = kernel["StoreVectorWidth"] + # storeWidth = 2 + if batchIdx == 0 and elementIdx == 0: + kStr += 
staticMultiply(vgpr(addrCVgpr), vgpr("Serial"), storeWidth * self.bpeCinternal, sgpr(tmpS01)) + # kStr += inst("v_mul_lo_u32", , "Partials buffer address") + kStr += inst("s_mov_b32", sgpr(tmpS01), 0, "Init sgpr offset") + else: + increment = (kernel["WavefrontSize"] * 4) * storeWidth * self.bpeCinternal + kStr += inst("s_add_u32", sgpr(tmpS01), sgpr(tmpS01), increment, "Inc sgpr offset") - # print("numVgprsPerAddr=%u, numVgprsPerDataPerVI=%u, numVgprPerValuC=%u"%(self.ss.cfg.numVgprsPerAddr, self.ss.cfg.numVgprsPerDataPerVI, self.ss.cfg.numVgprPerValuC)) - numVgprsPerElement = self.ss.cfg.numVgprPerValuC*gwvw + self.ss.cfg.numVgprsPerAddr + int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) + kStr += self.readCInput(kernel, ss, addrCalc, vc0, data, gwvw, addrCVgpr, sgpr(tmpS01), 'WS') + loadsIssued += 1 - if kernel["GroupLoadStore"] and kernel["ProblemType"]["UseBeta"]: - numVgprsPerElement += self.ss.cfg.numVgprsPerAddr - - #print self.vgprPool.state() - # Use VGPR up to next occupancy threshold: - maxVgprs = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), \ - self.getLdsSize(kernel), self.agprPool.size(), self.doubleVgpr) - if self.serializedStore: # get aggressive when serializedStore is on; not necessarily exclusive to this parameter - len(elements[edgeI]) - tl = [] - for i in range(self.vgprPool.size()-self.vgprPool.available(), maxVgprs): - tl.append(self.vgprPool.checkOut(1, "grow-pool up to next occupancy for GlobalWrite")) + ######################################## + # AccVgpr read + if kernel.enabledSetPrioSplitLDS: + kStr += inst("s_setprio", "0", "") + if codeAccVgprRead is not None: + regsPerScalar = self.bpeCinternal//self.bpr # register per scalar + # loop over store instructions within one batch + for elementIdx in range(0, len(batchElements)): + # loop over scalars within one store instruction + for vi in range(0, gwvw): + # loop over registers within one scalar + for rIdx in range(0, regsPerScalar): + tempStr = str(codeAccVgprRead.items().pop(0)) + kStr += tempStr.replace("__placeholder__", str(ss.elementSumIdx[elementIdx]*regsPerScalar + regsPerScalar*vi + rIdx)) + if kernel["StoreCInUnroll"] and not edge: + tempStr = tempStr.replace("__placeholder__",str(elementIdx*gwvw*regsPerScalar + regsPerScalar*vi + rIdx)) + accVgprRead.addCode(tempStr.replace("ValuC","L2GC")) + + if not kernel["MIArchVgpr"]: + kStr += inst("s_nop 1", "2 wait states required before reading vgpr") + + ######################################## + # Not Atomic + ######################################## + # edge has v_cndmask so loads or stores may not issue, hard to track vmcnt: + interleaveStoreVmcnt = self.interleaveStoreVmcnt and not edge + for elementIdx in range(0, len(batchElements)): + for vi in range(0, gwvw): + sumIdxV = ss.elementSumIdx[elementIdx] + vi + # covers sgemm, gemm_ex(HHS/HSS/BBS/BSS (HPA=T)), int8 (int8x4?) 
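# --- Editor's sketch (hypothetical helper; the x4 factor is assumed to be waves per workgroup) ---
# Summarizes the workspace addressing set up above for the fixup loads: each lane's VGPR base is
# Serial * StoreVectorWidth * bpeCinternal, and a scalar offset advances by one full wave-block
# per batch element, matching the (WavefrontSize * 4) * storeWidth * bpeCinternal increment above.
def _ws_element_offset(serial, element_idx, store_width, bpe, wavefront_size=64, waves_per_wg=4):
    lane_base = serial * store_width * bpe
    block = wavefront_size * waves_per_wg * store_width * bpe
    return lane_base + element_idx * block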
+ if kernel["ProblemType"]["ComputeDataType"].isInt32() or \ + kernel["ProblemType"]["ComputeDataType"].isSingle(): # covers sgemm/gemm_ex(HHS/HSS/BBS/BSS) + if self.db["ForceExpectedValue"]: + kStr += inst("v_mov_b32", vgpr("ValuC+%u"%sumIdxV), self.db["ValueCExpectedValue"], "force expected value" ) + if self.db["ForceVSerial"]: + kStr += inst("v_mov_b32", vgpr("ValuC+%u"%sumIdxV), vgpr("Serial"), "force expected value to serial" ) + if self.db["CheckValueC"]: + kStr += inst("s_mov_b32", sgpr(tmpS01), self.db["ValueCExpectedValue"], "Move expected value") + kStr += self.assert_eq(vgpr("ValuC+%u"%sumIdxV), sgpr(tmpS01)) + + ######################################## + # wait for batched load + if not interleaveStoreVmcnt: # beta and + kStr += inst("s_waitcnt", "vmcnt(0)", "wait C") + if self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + + # PreLoop LWVmcnt: When a vmcnt(cnt) is inserted here, means the GlobalLoad for PAP is finished + # So the preLoopVmcntDict value is meaningless since we no longer need to wait in next PreLoop + # And this only occurs when beta=true, so case must not be 2 or 3 + assert self.currPreLoopVmcntCase not in self.preLoopVmcntDict, \ + "PreLoopVmcntCase 2 or 3 shouldn't enter the beta true case" + + kStr += self.comment("apply mask, calc new C and issue writes") + #kStr += self.bomb() # can see store addresses just before the store inst + + if kernel["ProblemType"]["DestDataType"].isBFloat16() and kernel["ProblemType"]["HighPrecisionAccumulate"]: + vgprBf16Temp = tmpCVTVgpr + vgprBf16Mask = vgprBf16Temp + 1 + vgprFp32Nan = vgprBf16Temp + 2 + vgprBf16Inc = vgprBf16Temp + 3 + kStr += inst("v_mov_b32", vgpr(vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ) + kStr += inst("v_mov_b32", vgpr(vgprFp32Nan), "0x7fff0000", "fp32 Nan" ) + kStr += inst("v_mov_b32", vgpr(vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ) + + # DestDataType for 8bit Float can only be F8 or B8 + if kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8(): # F8 is always HPA + # make vgprF8Temp0 always even to use pk instruction later + if tmpCVTVgpr % 2 == 0: + vgprF8Temp0 = tmpCVTVgpr + vgprF8Temp1 = vgprF8Temp0 + 1 + vgprF8Max = vgprF8Temp0 + 2 + vgprF8Min = vgprF8Temp0 + 3 + else: + vgprF8Max = tmpCVTVgpr + vgprF8Temp0 = vgprF8Max + 1 + vgprF8Temp1 = vgprF8Max + 2 + vgprF8Min = vgprF8Max + 3 + + if kernel["ProblemType"]["Fp32toFp8SWClip"]: + # set flag of f32 NaN and +/- INF for v_cmp_class + vgprFp32NanInfFlag = vgprF8Min + 1 + kStr += inst("v_mov_b32", vgpr(vgprFp32NanInfFlag), "0x207", "flag for Nan and +/- inf" ) + # set max/min values for clipping + if kernel["ProblemType"]["DestDataType"].isFloat8(): + kStr += inst("v_mov_b32", vgpr(vgprF8Max), "0x43700000", "save 240.0f as max for clipping" ) + kStr += inst("v_mov_b32", vgpr(vgprF8Min), "0xC3700000", "save -240.0f as min for clipping" ) + else: #BFloat8 + kStr += inst("v_mov_b32", vgpr(vgprF8Max), "0x47600000", "save 57344.0f as max for clipping" ) + kStr += inst("v_mov_b32", vgpr(vgprF8Min), "0xC7600000", "save -57344`.0f as min for clipping" ) + + for elementIdx in range(0, len(batchElements)): + element = batchElements[elementIdx] + addr = ss.elementAddr[elementIdx].addrDVgpr + mask = ss.elementMask[elementIdx] + addrCalc = ss.elementAddr[elementIdx] + # d1 = element[0] + # d0 = element[1] + # vc1 = element[2] + vc0 = element[3] + sumIdx = ss.elementSumIdx[elementIdx] + + # apply in-bounds exec mask + if edge and not 
kernel["BufferStore"]: + kStr += inst("s_mov_b{}".format(wavelen), self.exec, sgpr(mask,laneSGPRC), "sgprs -> exec" ) + + # if beta: + # if GWVW=1 the half path still assumes we have + # at least two stores so does some combining across VI - + # for example assuming we can have two elements and can use pk_mul + # here: + if interleaveStoreVmcnt: # beta and + if self.archCaps["SeparateVscnt"]: + vmcnt = loadsIssued - elementIdx - 1 + vmComment = "{} = {} - {} - 1".format(vmcnt, loadsIssued, elementIdx) + else: + waitStoreCnt = storesIssued if not kernel["GroupLoadStore"] else 0 + vmcnt = loadsIssued - elementIdx + waitStoreCnt - 1 + vmComment = "{} = {} - {} + {} - 1".format(vmcnt, loadsIssued, elementIdx, waitStoreCnt) + + maxVmcnt = globalParameters["AsmCaps"][self.version]["MaxVmcnt"] + vmcnt = min(vmcnt, maxVmcnt) + #print "wmvcnt=", vmcnt + kStr += "\n" + if not atomicAddC: + kStr += inst("s_waitcnt", "vmcnt(%u)"%vmcnt, "wait C (interleaved) " + vmComment) + + # PreLoop LWVmcnt: When a vmcnt(cnt) is inserted here, means the GlobalLoad for PAP is finished + # So the preLoopVmcntDict value is meaningless since we no longer need to wait in next PreLoop + # And this only occurs when beta=true, so case must not be 2 or 3 + assert self.currPreLoopVmcntCase not in self.preLoopVmcntDict, \ + "PreLoopVmcntCase 2 or 3 shouldn't enter the beta true case" + + for vi in range(0, gwvw): + dataV = ss.elementData[elementIdx] + int(vi*ss.cfg.numVgprsPerDataPerVI) + sumIdxV = ss.elementSumIdx[elementIdx] + vi + if kernel["ProblemType"]["ComputeDataType"].isHalf(): + if not kernel["ProblemType"]["HighPrecisionAccumulate"]: + if self.asmCaps["HasWMMA"] and kernel["EnableMatrixInstructionStore"]: + dataV = ss.elementData[elementIdx] + int(vi / 2 * ss.cfg.numVgprsPerDataPerVI) + # if (vi % 2) == 0: + # kStr += inst("v_pk_mul_f16", vgpr(dataV), sgpr("Beta"), vgpr(dataV+0), \ + # "%s = C*beta ei=%u vi=%u"%(vgpr(dataV),elementIdx, vi)) + # else: + if (vi % 2) != 0: + kStr += inst("v_lshrrev_b32", vgpr(dataV), 16, vgpr(dataV), \ + "shift 16bit to get next half of packed ValueC") + # dataV+0 = new c = old c*beta + rC + kStr += inst("v_pk_add_f16", vgpr("ValuC+%u"%(sumIdxV)), vgpr(dataV), vgpr("ValuC+%u"%(sumIdxV)), \ + "sum*alpha + C*beta") + elif sumIdxV%2==0 or (not self.ss.cfg.halfDataRegPerVI and gwvw==1): + # dataV+0 = new c = old c*beta + # kStr += inst("v_pk_mul_f16", vgpr(dataV), sgpr("Beta"), vgpr(dataV+0), \ + # "%s = C*beta ei=%u vi=%u"%(vgpr(dataV),elementIdx, vi)) + # dataV+0 = new c = old c*beta + rC + kStr += inst("v_pk_add_f16", vgpr("ValuC+%u"%(sumIdxV//2)), vgpr(dataV), vgpr("ValuC+%u"%(sumIdxV//2)), \ + "sum*alpha + C*beta") + else: + pass # add will have been done previously + else: # HPA + # dataV+0 = new c = old c*beta + rC + # src0 = beta = f32 = opsel 00 + # src1 = dataV = f16.lo = opsel 10 or 11 depending on even/odd + # src2 = sumIdxV = f32 = opsel 00 + dataCExternal = ss.elementData[elementIdx] + vi//2 + hi16 = (vi + gwvw*vc0) % 2 + # TODO try to replace with add? 
need opsel for f16 src + # kStr += inst(self.mixinst, vgpr("ValuC+%u"%sumIdxV), sgpr("Beta"), \ + kStr += inst(self.mixinst, vgpr("ValuC+%u"%sumIdxV), 1, \ + vgpr(dataCExternal), vgpr("ValuC+%u"%sumIdxV), \ + "op_sel:[0,%u,0] op_sel_hi:[0,1,0]" % (hi16), \ + "//C*=beta") + + elif kernel["ProblemType"]["ComputeDataType"].isBFloat16(): + if kernel["ProblemType"]["HighPrecisionAccumulate"]: + # dataV+0 = new c = old c*beta + rC + # src0 = beta = f32 = opsel 00 + # src1 = dataV = f16.lo = opsel 10 or 11 depending on even/odd + # src2 = sumIdxV = f32 = opsel 00 + dataCExternal = ss.elementData[elementIdx] + vi//2 + if (vi%2) == 1: + kStr += inst("v_and_b32", vgpr(tmpVgpr), vgpr(dataCExternal), vgpr(vgprBf16Mask), "convert bf16 to fp32") + else: + kStr += inst("v_lshlrev_b32", vgpr(tmpVgpr), "16", vgpr(dataCExternal), "convert bf16 to fp32" ) + kStr += inst("v_add_f32", vgpr("ValuC+%u"%sumIdxV), vgpr("ValuC+%u"%sumIdxV), vgpr(tmpVgpr), "accum partials") + + # 8bit-Float: dest can only be either F8 or B8 + elif kernel["ProblemType"]["ComputeDataType"].isFloat8() or kernel["ProblemType"]["ComputeDataType"].isBFloat8(): # F8 is always HPA + # dataV+0 = new c = old c*beta + rC + dataCExternal = ss.elementData[elementIdx] + vi//4 + + #restructuring the code to handle edge case + if kernel["ProblemType"]["Fp8NoPackUpConversion"] or gwvw == 1: + byteSel = "src0_sel:BYTE_0" if vi%4==0 else "src0_sel:BYTE_1" if vi%4==1 else "src0_sel:BYTE_2" if vi%4==2 else "src0_sel:BYTE_3" + if kernel["ProblemType"]["DestDataType"].isFloat8(): + kStr += "v_cvt_f32_fp8 v%u, %s %s // convert fp8 in lo_byte[0] to f32%s"%(vgprF8Temp0, vgpr(dataCExternal), byteSel, self.endLine) + else: + kStr += "v_cvt_f32_bf8 v%u, %s %s // convert bf8 in lo_byte[0] to f32%s"%(vgprF8Temp0, vgpr(dataCExternal), byteSel, self.endLine) + kStr += inst("v_add_f32", vgpr("ValuC+%u"%sumIdxV), vgpr("ValuC+%u"%sumIdxV), vgpr(vgprF8Temp0), "accum partials") + else: ## use packing inst + if kernel["ProblemType"]["DestDataType"].isFloat8(): + f32Tof8PkInst="V_cvt_pk_f32_fp8" + else: #bf8 + f32Tof8PkInst="V_cvt_pk_f32_bf8" + + if (vi%2) == 1: ## vi%4 == 1 or vi%4 == 3 + wordSel = "" if vi%4 == 1 else "src0_sel:WORD_1" + wordTxt = "lo_16" if vi%4==1 else "hi_16" + kStr += f32Tof8PkInst + " v[%u:%u], %s %s // convert two f8 in %s to f32%s"%(vgprF8Temp0, vgprF8Temp1, vgpr(dataCExternal), wordSel, wordTxt, self.endLine) + + kStr += inst("v_add_f32", vgpr("ValuC+%u"%(sumIdxV-1)), vgpr("ValuC+%u"%(sumIdxV-1)), vgpr(vgprF8Temp0), "accum partials") + kStr += inst("v_add_f32", vgpr("ValuC+%u"%sumIdxV), vgpr("ValuC+%u"%sumIdxV), vgpr(vgprF8Temp1), "accum partials") + + elif kernel["ProblemType"]["ComputeDataType"].isSingle(): + kStr += inst("v_add_f32", vgpr("ValuC+%u"%sumIdxV), vgpr("ValuC+%u"%sumIdxV), vgpr(dataV+0), "accum partials") + + elif kernel["ProblemType"]["ComputeDataType"].isInt32(): + # assume we will need to replace v_mac_f32 with v_add_u32 and s_mul_lo_i32 + # v_mad_i32_i24 + kStr += inst("_v_add_u32", vgpr("ValuC+%u"%sumIdxV), vgpr(dataV+0), vgpr("ValuC+%u"%sumIdxV), \ + "accum partials") + + elif kernel["ProblemType"]["ComputeDataType"].isDouble(): + # dataV+0 = new c = old c*beta + if not atomicAddC: + kStr += inst("v_add_f64", vgpr("ValuC+%u"%(sumIdxV*2),2), vgpr("ValuC+%u"%(sumIdxV*2),2), vgpr(dataV+0,2), "accum partials") + + # single precision complex + elif kernel["ProblemType"]["ComputeDataType"].isSingleComplex(): + kStr += inst("v_add_f32", vgpr("ValuC+%u"%(sumIdxV*2)), vgpr("ValuC+%u"%(sumIdxV*2)), vgpr(dataV+0), "accum partials 
real") + kStr += inst("v_add_f32", vgpr("ValuC+%u"%(sumIdxV*2+1)), vgpr("ValuC+%u"%(sumIdxV*2+1)), vgpr(dataV+1), "accum partials imag") + + # double precision complex + elif kernel["ProblemType"]["ComputeDataType"].isDoubleComplex(): + kStr += inst("v_add_f64", vgpr("ValuC+%u"%(sumIdxV*4+0),2), vgpr("ValuC+%u"%(sumIdxV*4+0),2), vgpr(dataV+0,2), "accum partials real") + kStr += inst("v_add_f64", vgpr("ValuC+%u"%(sumIdxV*4+2),2), vgpr("ValuC+%u"%(sumIdxV*4+2),2), vgpr(dataV+2,2), "accum partials imag") + + ######################################## + # AccVgpr write + if kernel.enabledSetPrioSplitLDS: + kStr += inst("s_setprio", "0", "") + if codeAccVgprWrite is not None: + regsPerScalar = self.bpeCinternal//self.bpr # register per scalar + # loop over store instructions within one batch + for elementIdx in range(0, len(batchElements)): + # loop over scalars within one store instruction + for vi in range(0, gwvw): + # loop over registers within one scalar + for rIdx in range(0, regsPerScalar): + tempStr = str(codeAccVgprWrite.items().pop(0)) + kStr += tempStr.replace("__placeholder__", str(ss.elementSumIdx[elementIdx]*regsPerScalar + regsPerScalar*vi + rIdx)) + if kernel["StoreCInUnroll"] and not edge: + tempStr = tempStr.replace("__placeholder__",str(elementIdx*gwvw*regsPerScalar + regsPerScalar*vi + rIdx)) + accVgprRead.addCode(tempStr.replace("ValuC","L2GC")) + + if not kernel["MIArchVgpr"]: + kStr += inst("s_nop 1", "2 wait states required before reading vgpr") + + #kStr += self.bomb(5) + if self.db["CheckStoreC"]>=0: + useBuffer = kernel["BufferStore"] + # Note - CheckStoreC won't work for EDGE store cases since they load 0 for OOB, would need more sophisticated check + # Note - TODO- CheckStoreC also won't work for StoreRemap + kStr += inst("s_waitcnt", "vmcnt(0)", "CheckStoreC, wait for stores to complete" ) + if self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + for elementIdx in range(0, len(batchElements)): + addr = ss.elementAddr[elementIdx].addrDVgpr + sumIdx = ss.elementSumIdx[elementIdx] + + bps = kernel["ProblemType"]["DestDataType"].numBytes() * gwvw + if kernel["BufferStore"]: + addr0 = vgpr(addr) + addr1 = sgpr("SrdC", 4) + else: + addr0 = vgpr(addr,2) + addr1 = "" + + if kernel["ProblemType"]["DestDataType"].isHalf() or kernel["ProblemType"]["DestDataType"].isBFloat16(): + if not kernel["ProblemType"]["HighPrecisionAccumulate"]: + kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx//2, \ + addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False, hi16=sumIdx%2).toStr() + else: + kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx, \ + addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False, hi16=0).toStr() + elif kernel["ProblemType"]["DestDataType"].isInt32() or kernel["ProblemType"]["DestDataType"].isSingle(): + kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx, \ + addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr() + elif kernel["ProblemType"]["DestDataType"].isDouble() or kernel["ProblemType"]["DestDataType"].isSingleComplex() : + kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx*2, \ + addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr() + elif kernel["ProblemType"]["DestDataType"].isDoubleComplex(): + kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx*4, \ + addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr() + kStr += inst("s_waitcnt", "vmcnt(0)", "CheckStoreC, wait for stores to complete" ) + if 
self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + + # Add checks for expected values: + kStr += inst("s_mov_b32", sgpr(tmpS01), self.db["CheckStoreC"], "expected value") + for elementIdx in range(0, len(batchElements)): + sumIdx = ss.elementSumIdx[elementIdx] + # Need to fix for other types: + assert (kernel["ProblemType"]["DestDataType"].isSingle() or kernel["ProblemType"]["DestDataType"].isInt32()) + kStr += self.assert_eq(vgpr(sumIdx), sgpr(tmpS01)) + + + if edge and (not kernel["BufferStore"]): # atomic or + # subsequent batch must start with full exec mask + # BufferStore doesn't need exec since it used buffer range checking when + # possible + kStr += inst("s_mov_b{}".format(wavelen), self.exec, -1, "full mask -> exec" ) + + if self.db["ConservativeWaitCnt"] & 0x40: + kStr += "s_barrier // debug\n" + kStr += inst("s_waitcnt", "vmcnt(0)", "ConservativeWaitCnt" ) + if self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + kStr += "s_barrier // debug\n" + ######################################## + # End Not Atomic + ######################################## + + # return registers to pool: + lastData = -1 + for elementIdx in range(0, len(batchElements)): + if not ss.sharedColDVgprs: + addrDVgpr = ss.elementAddr[elementIdx].addrDVgpr + addrCVgpr = ss.elementAddr[elementIdx].addrCVgpr + self.vgprPool.checkIn(addrDVgpr) + if addrCVgpr != addrDVgpr: + self.vgprPool.checkIn(addrCVgpr) + + data = ss.elementData[elementIdx] + if data != 0: + if data != lastData: + self.vgprPool.checkIn(data) + lastData = data + + self.ss.firstBatch = False + self.ss.checkInTempVgprC() + + if self.serializedStore: + kStr += inst("s_nop 0", "1 wait state required when next inst writes vgprs held by previous dwordx4 store inst") + + # Update the store cnt to preLoopVmcntDict for Case2/3 + # (No need to update for Case0:'Undefined' or Case4:'OrdNLL_B1_Store') + if self.currPreLoopVmcntCase in self.preLoopVmcntDict: + if not self.archCaps["SeparateVscnt"]: + self.preLoopVmcntDict[self.currPreLoopVmcntCase] += storesIssued + + return kStr + + ############################################################################## + # Workspace SRD for FixupStep + ############################################################################## + def computeWorkspaceSrd(self, kernel, sCtaIdx, tmpSgpr = None): + kStr = "" + + # Base Address + kStr += inst("s_mov_b32", sgpr("SrdWS+0"), sgpr("AddressWS+0"), "init SRD base address (lower)" ) + kStr += inst("s_mov_b32", sgpr("SrdWS+1"), sgpr("AddressWS+1"), "init SRD base address (upper) + other fields" ) + kStr += inst("s_mov_b32", sgpr("SrdWS+2"), "BufferOOB", "") + kStr += inst("s_mov_b32", sgpr("SrdWS+3"), "Srd127_96", "Set bits 127_96 in post-loop SRD") + + tmpSgprRef = None + if tmpSgpr == None: + tmpSgprRef = self.getTmpSgpr(1) + tmpSgpr = tmpSgprRef.idx() + + assert kernel["BufferStore"] + kStr += "\n" + kStr += inst("s_mul_i32", sgpr(tmpSgpr), hex(kernel["MacroTile0"]*kernel["MacroTile1"]*self.bpeCinternal), sCtaIdx, "Offset to correct partials tile") + kStr += inst("s_add_u32", sgpr("SrdWS+0"), sgpr("SrdWS+0"), sgpr(tmpSgpr), "add lo to SRD") + kStr += inst("s_addc_u32", sgpr("SrdWS+1"), sgpr("SrdWS+1"), 0, "add hi to SRD") + + return kStr + + ############################################################################## + # Fixup Step + ############################################################################## + def fixupStep(self, kernel, vectorWidths, elements, edges, tmpVgpr, tmpCVTVgpr, 
sCtaIdx, skStoreLabel): + kStr = "" + + edgeLabels = {} + for edge in edges: + edgeLabels[edge] = self.getNamedLabelUnique("Fixup_E%u" % (1 if edge else 0)) + # branch if Edge0 or Edge1 + if False in edges and True in edges: + kStr += self.checkIsEdge(kernel, "%s" % edgeLabels[True]) + + # by now we either jumped to E1 or stayed at E0 + for edge in edges: + # write label for batch case + kStr += "%s:%s"%(edgeLabels[edge], self.endLine) + + PreLoopVmcntCaseStr = "" + # not generate Case 2 if StoreCInUnroll with StoreVectorWidth==1 (Case 2 will be same as Case 3) + if self.canOptimizePreLoopLWVmcnt: + if edge or (kernel["StoreCInUnroll"] and kernel["StoreVectorWidth"]==1): + self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_E1_Store + else: + self.currPreLoopVmcntCase = PreLoopVmcntCase.OptNLL_Store + PreLoopVmcntCaseStr = inst("s_mov_b32", sgpr("PreLoopLWVmcntCase"), hex(self.currPreLoopVmcntCase.value), \ + "for optimizing next PreLoop LW vmcnt, set to Case%u"%self.currPreLoopVmcntCase.value) + # reset vmcnt if the dict has this key (OptNLL_Store, OrdNLL_E1_Store), + # OrdNLL_B1_Store is excluded + if self.currPreLoopVmcntCase in self.preLoopVmcntDict: + self.preLoopVmcntDict[self.currPreLoopVmcntCase] = 0 + + edgeI = edge + #edgeI = True # set to True to disable vector stores + gwvw = vectorWidths[edgeI] + + ######################################## + # Calculate Vgprs for Write Batching + ######################################## + + self.ss = self.StoreState(self, kernel, gwvw, edge, True, False, elements[edgeI], isWorkspace=True) + + # how many vgprs are needed for zero elements + # 2 for addressC in vgpr for addition - already checked out + # 2 for coord0,1 of thread - already checked out + # 2 for tmp - already checked out + + # 5 = how many vgprs are needed per element (flat) + # - 2 for addr + # - 3 for GLOBAL_OFFSET_C calculation (can overlap below, therefore max) + # - if beta gwvw*rpe for new value + # - if atomic 2*rpe for old and cmp values + + # print("numVgprsPerAddr=%u, numVgprsPerDataPerVI=%u, numVgprPerValuC=%u"%(self.ss.cfg.numVgprsPerAddr, self.ss.cfg.numVgprsPerDataPerVI, self.ss.cfg.numVgprPerValuC)) + numVgprsPerElement = self.ss.cfg.numVgprPerValuC*gwvw + self.ss.cfg.numVgprsPerAddr + int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) + + if kernel["GroupLoadStore"] and kernel["ProblemType"]["UseBeta"]: + numVgprsPerElement += self.ss.cfg.numVgprsPerAddr + + #print self.vgprPool.state() + # Use VGPR up to next occupancy threshold: + maxVgprs = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), \ + self.getLdsSize(kernel), self.agprPool.size(), self.doubleVgpr) + if self.serializedStore: # get aggressive when serializedStore is on; not necessarily exclusive to this parameter + len(elements[edgeI]) + tl = [] + for i in range(self.vgprPool.size()-self.vgprPool.available(), maxVgprs): + tl.append(self.vgprPool.checkOut(1, "grow-pool up to next occupancy for GlobalWrite")) + for t in tl: + self.vgprPool.checkIn(t) + align = 1 + # align adjustment + if self.ss.cfg.numVgprsPerAddr > 1: + align = max(align, self.ss.cfg.numVgprsPerAddr) + if self.ss.cfg.numVgprPerValuC*gwvw > 1: + align = max(align, self.ss.cfg.numVgprPerValuC*gwvw) + if int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) > 1: + align = max(align, int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw))) + numVgprAvailable = self.vgprPool.availableBlock(numVgprsPerElement, align) + + # Grow the register pool if needed - we need enough regs for at least one element + # Unfortunate since this 
means the write logic is setting the VGPR requirement + # for the entire kernel but at least we have a functional kernel. + # Before growing the pool, see if we can shrink the write vector width instead? + # TODO : the vgprSerial is needed for-ever and if we grow here will split the + # range of the tmps. Maybe want to move vgprSerial to first vgpr? + + # TODO: Minimum elems for StoreRemap + # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS + minElements = 1 + if kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16(): + minElements = 2 + elif kernel["ProblemType"]["DataType"].is8bitFloat(): + minElements = 4 + + minNeeded = minElements * numVgprsPerElement + shrinkDb = 0 + if shrinkDb: + print("numVgprAvailable=", numVgprAvailable, "minElements=", minElements, "minNeeded=", minNeeded) + if numVgprAvailable < minNeeded: + gwvwOrig = gwvw + currentOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size(), self.agprPool.size(), self.doubleVgpr) + futureOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size() - numVgprAvailable + minNeeded, self.agprPool.size(), self.doubleVgpr) + + if shrinkDb: + print("currentOccupancy=%u futureOccupancy=%u VGPRs=%u numVgprAvail=%u vgprPerElem=%u" \ + % (currentOccupancy, futureOccupancy, self.vgprPool.size(), \ + numVgprAvailable, minElements*numVgprsPerElement)) + if futureOccupancy > currentOccupancy: + if shrinkDb: + print("warning: %s growing VGPR for GlobalWrite batching - this may bloat VGPR usage" % \ + (self.kernelName)) + print(" numVgprAvailable=", numVgprAvailable, \ + "numVgprsPerElement=", numVgprsPerElement, \ + "gwvw=", gwvw) + elif gwvw != gwvwOrig: + self.ss.gwvw = gwvw # make both representations consistent + if shrinkDb: + print2("info: %s shrank gwvw from %u to %u but kept occupancy same=%u." \ + % (self.kernelName, gwvwOrig, gwvw, currentOccupancy)) + + if numVgprAvailable < minElements*numVgprsPerElement: + print2("info: growing pool += %d * %d for GlobalWrite\n" \ + % (minElements,numVgprsPerElement)) + print2(self.vgprPool.state()) + tl = [] + for i in range(0,minElements): + tl.append(self.vgprPool.checkOut(numVgprsPerElement, "grow-pool for GlobalWrite")) for t in tl: self.vgprPool.checkIn(t) - align = 1 - # align adjustment - if self.ss.cfg.numVgprsPerAddr > 1: - align = max(align, self.ss.cfg.numVgprsPerAddr) - if self.ss.cfg.numVgprPerValuC*gwvw > 1: - align = max(align, self.ss.cfg.numVgprPerValuC*gwvw) - if int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) > 1: - align = max(align, int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw))) - numVgprAvailable = self.vgprPool.availableBlock(numVgprsPerElement, align) - - # Grow the register pool if needed - we need enough regs for at least one element - # Unfortunate since this means the write logic is setting the VGPR requirement - # for the entire kernel but at least we have a functional kernel. - # Before growing the pool, see if we can shrink the write vector width instead? - # TODO : the vgprSerial is needed for-ever and if we grow here will split the - # range of the tmps. Maybe want to move vgprSerial to first vgpr? - - # TODO: Minimum elems for StoreRemap - # TODO: Which of DataType or DestDataType is in a better sense? 
0114: Check Using DestDataType + HSS - minElements = 1 - if kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16(): - minElements = 2 - elif kernel["ProblemType"]["DataType"].is8bitFloat(): - minElements = 4 - - minNeeded = minElements * numVgprsPerElement - shrinkDb = 0 + numVgprAvailable = self.vgprPool.available() + print2(self.vgprPool.state()) + + # print("NumVgprAvailable", numVgprAvailable) + if numVgprsPerElement: + numElementsPerBatch = numVgprAvailable // numVgprsPerElement + else: + numElementsPerBatch = len(elements[edgeI]) # max, do 'em all + + assert(self.numVgprValuC % gwvw == 0) # sanity check + + numElementsPerBatch = numElementsPerBatch if not kernel["NumElementsPerBatchStore"] else min(kernel["NumElementsPerBatchStore"],numElementsPerBatch) + + if shrinkDb: + print("NumElementsPerBatch=", numElementsPerBatch, "LimitedBySgprs=", self.ss.cfg.numElementsPerBatchLimitedBySgprs, \ + "WARNING" if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch else "okay") + if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch: + numElementsPerBatch = self.ss.cfg.numElementsPerBatchLimitedBySgprs + + # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS + if (kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()): + # only do an even number of halves - since these share hi/lo pieces of some registers? + if numElementsPerBatch > 1: + numElementsPerBatch = int(numElementsPerBatch/2)*2 + elif not kernel["EnableMatrixInstruction"]: + # (excluding MFMA+LSU case. It can work without an issue) + # The globalWriteBatch routine below can't handle odd elements per batch + # and 0 elements per batch is illegal. + # so if we don't have *GPR resources to handle a larger batch then need + # to mark overflowedResources rather than generate a kernel that won't work. + # It might be possible to fix globalWriteBatch to handle this case but these + # are likely to be low-performing so likely not worth optimizing. 
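The batch sizing above is plain arithmetic over the VGPR budget: each store element costs numVgprPerValuC*gwvw accumulator registers, numVgprsPerAddr address registers, and ceil(numVgprsPerDataPerVI*gwvw) data registers, and half/bfloat16 batches are rounded down to an even count because two 16-bit values share the hi/lo halves of one register. A minimal standalone sketch with assumed register counts (not the values StoreState actually computes):

    from math import ceil

    def elements_per_batch(num_vgpr_available, gwvw, num_vgpr_per_valu_c,
                           num_vgprs_per_addr, num_vgprs_per_data_per_vi,
                           dtype_is_16bit):
        # VGPRs needed per store element: accumulator regs + address regs + per-VI data regs
        per_element = (num_vgpr_per_valu_c * gwvw
                       + num_vgprs_per_addr
                       + int(ceil(num_vgprs_per_data_per_vi * gwvw)))
        n = num_vgpr_available // per_element
        # half/bf16 values share hi/lo halves of a register, so keep the batch even
        if dtype_is_16bit and n > 1:
            n = (n // 2) * 2
        return per_element, n

    # assumed numbers: 40 spare VGPRs, gwvw=4, 1 accumulation reg per value,
    # 1 address reg per element, 0.5 data regs per value, fp16 in/out
    print(elements_per_batch(40, 4, 1, 1, 0.5, True))   # -> (7, 4)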
+ if shrinkDb: + print("WARNING: half requires at least two elements per batch") + self.overflowedResources = 3 + #elif kernel["ProblemType"]["DataType"].is8bitFloat(): + # if numElementsPerBatch > 1: + # numElementsPerBatch = int(numElementsPerBatch/4)*4 + + assert numElementsPerBatch > 0, "numElementsPerBatch=0 for %s"%self.kernelName + + # if no atomics and no edge, then write whole vectors + # ERROR commented out in globalWriteELements, causes numVectorsPerBatch to not be int + # if not edge: # not atomic and + # numVectorsPerBatch = numElementsPerBatch / kernel["GlobalWriteVectorWidth"] + # #print " NumVectorsPerBatch", numVectorsPerBatch + # numElementsPerBatch = numVectorsPerBatch * kernel["GlobalWriteVectorWidth"] + numBatches = max(1, ceil_divide(len(elements[edgeI]),numElementsPerBatch)) + + numSgprs = self.ss.cfg.numTempSgprPerBatch + self.ss.cfg.numMaskSgprPerBatch + self.ss.cfg.numMaskSgprPerElement * numElementsPerBatch + + if self.db["PrintStoreRegisterDb"]: + print("edgeI", edgeI, "NumBatches", numBatches, "NumElementsPerBatch", numElementsPerBatch, "numVgprsPerElement", numVgprsPerElement, "len(elements[edgeI])", len(elements[edgeI])) + print ("numSgprs=", numSgprs, "sgprPool.size()=", self.sgprPool.size(), "numTempSgprPerBatch=", self.ss.cfg.numTempSgprPerBatch, + "numMaskSgprPerBatch=", self.ss.cfg.numMaskSgprPerBatch, "numMaskSgprPerElement=", self.ss.cfg.numMaskSgprPerElement) + print(self.sgprPool.state()) + kStr += self.comment("edge=%d, allocate %u sgpr. perBatchTmpS=%u perBatchMaskS=%u perElementMaskS=%u elementsPerBatch=%u" % + (edgeI, numSgprs, self.ss.cfg.numTempSgprPerBatch, self.ss.cfg.numMaskSgprPerBatch, self.ss.cfg.numMaskSgprPerElement, numElementsPerBatch)) + #kStr += "// storeStats, %d, %d, %d\n"% (edgeI, numSgprs, numElementsPerBatch) + # so if we don't have *GPR resources to handle a larger batch then need + # to mark overflowedResources rather than generate a kernel that won't work. + + tmpSgprRef = self.getTmpSgpr(numSgprs, 2) + tmpSgpr = tmpSgprRef.idx() + + elementSgprs = tmpSgpr + self.ss.cfg.numTempSgprPerBatch + + codeAccVgprRead = deepcopy(self.codeAccVgprRead) if self.serializedStore else None + codeAccVgprWrite = deepcopy(self.codeAccVgprWrite) if self.serializedStore else None + + kStr += self.computeWorkspaceSrd(kernel, sgpr(sCtaIdx), tmpSgpr) + + for batchIdx in range(0, numBatches): + elementStartIdx = batchIdx * numElementsPerBatch + elementStopIdx = min( elementStartIdx + numElementsPerBatch, len(elements[edgeI]) ) + elementsThisBatch = elements[edgeI][elementStartIdx:elementStopIdx] + #print("BATCH[%u/%u]: elements[edgeI][%u:%u] VGPRs=%u" % (batchIdx, numBatches, elementStartIdx, elementStopIdx,numVgprsPerElement )) + # elementVgprs can be large and should be perfectly tuned to the number of available + # VGPRS. 
We do not want to accidentally overflow and grow the pool here: + + kStr += self.fixupBatch(kernel, self.ss, batchIdx, edge, gwvw, \ + elementsThisBatch, self.coord0, self.coord1, self.addrD, self.addrC, \ + tmpVgpr, tmpCVTVgpr, \ + elementSgprs, tmpSgpr, codeAccVgprRead, codeAccVgprWrite) + # delay PreLoopVmcntCase code after globalWrite + if self.canOptimizePreLoopLWVmcnt: + kStr += PreLoopVmcntCaseStr + + del self.ss + + # Finish one write path, reset currPreLoopVmcntCase to Undefined + self.currPreLoopVmcntCase = PreLoopVmcntCase.Undefined + + # kStr += inst("s_branch", skStoreLabel, "jump to store") + + return kStr + + def writePartials(self, kernel, vectorWidths, elements, edges, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel): + kStr = "" + + partialsLabels = {} + for edge in edges: + partialsLabels[edge] = self.getNamedLabelUnique("GW_Partials_E%u" % ( 1 if edge else 0) ) + + if False in edges and True in edges: + kStr += self.checkIsEdge(kernel, "%s" % partialsLabels[True]) + + for edge in edges: + kStr += "%s:%s"%(partialsLabels[edge], self.endLine) + kStr += self.computeWorkspaceSrd(kernel, sgpr("StreamKIdx")) + kStr += self.partialsWriteProcedure(kernel, vectorWidths, elements, False, False, edge, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel) + + return kStr + + + ############################################################################## + # Partials Write Procedure + ############################################################################## + def partialsWriteProcedure(self, kernel, vectorWidths, elements, alpha, beta, edge, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel): + kStr = "" + + PreLoopVmcntCaseStr = "" + # not generate Case 2 if StoreCInUnroll with StoreVectorWidth==1 (Case 2 will be same as Case 3) + if self.canOptimizePreLoopLWVmcnt: + if beta: + self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_B1_Store + elif edge or (kernel["StoreCInUnroll"] and kernel["StoreVectorWidth"]==1): + self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_E1_Store + else: + self.currPreLoopVmcntCase = PreLoopVmcntCase.OptNLL_Store + PreLoopVmcntCaseStr = inst("s_mov_b32", sgpr("PreLoopLWVmcntCase"), hex(self.currPreLoopVmcntCase.value), \ + "for optimizing next PreLoop LW vmcnt, set to Case%u"%self.currPreLoopVmcntCase.value) + # reset vmcnt if the dict has this key (OptNLL_Store, OrdNLL_E1_Store), + # OrdNLL_B1_Store is excluded + if self.currPreLoopVmcntCase in self.preLoopVmcntDict: + self.preLoopVmcntDict[self.currPreLoopVmcntCase] = 0 + + edgeI = edge + #edgeI = True # set to True to disable vector stores + gwvw = vectorWidths[edgeI] + #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic + + ######################################## + # Calculate Vgprs for Write Batching + ######################################## + + self.ss = self.StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], isWorkspace=True) + + # how many vgprs are needed for zero elements + # 2 for addressC in vgpr for addition - already checked out + # 2 for coord0,1 of thread - already checked out + # 2 for tmp - already checked out + + # 5 = how many vgprs are needed per element (flat) + # - 2 for addr + # - 3 for GLOBAL_OFFSET_C calculation (can overlap below, therefore max) + # - if beta gwvw*rpe for new value + # - if atomic 2*rpe for old and cmp values + + # print("numVgprsPerAddr=%u, numVgprsPerDataPerVI=%u, numVgprPerValuC=%u"%(self.ss.cfg.numVgprsPerAddr, self.ss.cfg.numVgprsPerDataPerVI, self.ss.cfg.numVgprPerValuC)) + numVgprsPerElement = 
self.ss.cfg.numVgprPerValuC*gwvw + self.ss.cfg.numVgprsPerAddr + int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) + + if kernel["GroupLoadStore"] and kernel["ProblemType"]["UseBeta"]: + numVgprsPerElement += self.ss.cfg.numVgprsPerAddr + + #print self.vgprPool.state() + # Use VGPR up to next occupancy threshold: + maxVgprs = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), \ + self.getLdsSize(kernel), self.agprPool.size(), self.doubleVgpr) + if self.serializedStore: # get aggressive when serializedStore is on; not necessarily exclusive to this parameter + len(elements[edgeI]) + tl = [] + for i in range(self.vgprPool.size()-self.vgprPool.available(), maxVgprs): + tl.append(self.vgprPool.checkOut(1, "grow-pool up to next occupancy for GlobalWrite")) + for t in tl: + self.vgprPool.checkIn(t) + align = 1 + # align adjustment + if self.ss.cfg.numVgprsPerAddr > 1: + align = max(align, self.ss.cfg.numVgprsPerAddr) + if self.ss.cfg.numVgprPerValuC*gwvw > 1: + align = max(align, self.ss.cfg.numVgprPerValuC*gwvw) + if int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) > 1: + align = max(align, int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw))) + numVgprAvailable = self.vgprPool.availableBlock(numVgprsPerElement, align) + + # Grow the register pool if needed - we need enough regs for at least one element + # Unfortunate since this means the write logic is setting the VGPR requirement + # for the entire kernel but at least we have a functional kernel. + # Before growing the pool, see if we can shrink the write vector width instead? + # TODO : the vgprSerial is needed for-ever and if we grow here will split the + # range of the tmps. Maybe want to move vgprSerial to first vgpr? + + # TODO: Minimum elems for StoreRemap + # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS + minElements = 1 + if kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16(): + minElements = 2 + elif kernel["ProblemType"]["DataType"].is8bitFloat(): + minElements = 4 + + minNeeded = minElements * numVgprsPerElement + shrinkDb = 0 + if shrinkDb: + print("numVgprAvailable=", numVgprAvailable, "minElements=", minElements, "minNeeded=", minNeeded) + if numVgprAvailable < minNeeded: + gwvwOrig = gwvw + currentOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size(), self.agprPool.size(), self.doubleVgpr) + futureOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size() - numVgprAvailable + minNeeded, self.agprPool.size(), self.doubleVgpr) + + if shrinkDb: + print("currentOccupancy=%u futureOccupancy=%u VGPRs=%u numVgprAvail=%u vgprPerElem=%u" \ + % (currentOccupancy, futureOccupancy, self.vgprPool.size(), \ + numVgprAvailable, minElements*numVgprsPerElement)) + if futureOccupancy > currentOccupancy: + if shrinkDb: + print("warning: %s growing VGPR for GlobalWrite batching - this may bloat VGPR usage" % \ + (self.kernelName)) + print(" numVgprAvailable=", numVgprAvailable, \ + "numVgprsPerElement=", numVgprsPerElement, "atomic=", atomic, \ + "beta=", beta, "gwvw=", gwvw) + elif gwvw != gwvwOrig: + self.ss.gwvw = gwvw # make both representations consistent + if shrinkDb: + print2("info: %s shrank gwvw from %u to %u but kept occupancy same=%u." 
\ + % (self.kernelName, gwvwOrig, gwvw, currentOccupancy)) + + if numVgprAvailable < minElements*numVgprsPerElement: + print2("info: growing pool += %d * %d for GlobalWrite\n" \ + % (minElements,numVgprsPerElement)) + print2(self.vgprPool.state()) + tl = [] + for i in range(0,minElements): + tl.append(self.vgprPool.checkOut(numVgprsPerElement, "grow-pool for GlobalWrite")) + for t in tl: + self.vgprPool.checkIn(t) + numVgprAvailable = self.vgprPool.available() + print2(self.vgprPool.state()) + + # set atomicW after we potentially resize GWVW + atomicW = min(gwvw, kernel["VectorAtomicWidth"]) + + # print("NumVgprAvailable", numVgprAvailable) + if numVgprsPerElement: + numElementsPerBatch = numVgprAvailable // numVgprsPerElement + else: + numElementsPerBatch = len(elements[edgeI]) # max, do 'em all + + assert(self.numVgprValuC % gwvw == 0) # sanity check + + numElementsPerBatch = numElementsPerBatch if not kernel["NumElementsPerBatchStore"] else min(kernel["NumElementsPerBatchStore"],numElementsPerBatch) + + if shrinkDb: + print("NumElementsPerBatch=", numElementsPerBatch, "LimitedBySgprs=", self.ss.cfg.numElementsPerBatchLimitedBySgprs, \ + "WARNING" if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch else "okay") + if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch: + numElementsPerBatch = self.ss.cfg.numElementsPerBatchLimitedBySgprs + + # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS + if (kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()): + # only do an even number of halves - since these share hi/lo pieces of some registers? + if numElementsPerBatch > 1: + numElementsPerBatch = int(numElementsPerBatch/2)*2 + elif not kernel["EnableMatrixInstruction"]: + # (excluding MFMA+LSU case. It can work without an issue) + # The globalWriteBatch routine below can't handle odd elements per batch + # and 0 elements per batch is illegal. + # so if we don't have *GPR resources to handle a larger batch then need + # to mark overflowedResources rather than generate a kernel that won't work. + # It might be possible to fix globalWriteBatch to handle this case but these + # are likely to be low-performing so likely not worth optimizing. 
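The grow-pool idiom above (also used in fixupStep earlier and globalWriteProcedure later) never keeps the temporary registers: it checks out minElements*numVgprsPerElement VGPRs and immediately checks them back in, which is enough to push up the footprint the kernel will reserve. A toy model with assumed pool semantics (not Tensile's actual RegisterPool API) showing why checkOut-then-checkIn grows capacity:

    class ToyVgprPool:
        """Toy grow-on-demand pool (assumed semantics, not Tensile's RegisterPool)."""
        def __init__(self, size):
            self._size = size      # footprint the kernel will have to reserve
            self._in_use = 0
            self._alloc = {}       # start index -> count

        def checkOut(self, n, tag=""):
            start = self._in_use
            self._in_use += n
            self._size = max(self._size, self._in_use)   # grow on demand
            self._alloc[start] = n
            return start

        def checkIn(self, start):
            self._in_use -= self._alloc.pop(start)

        def size(self):
            return self._size

    pool = ToyVgprPool(size=8)
    tl = [pool.checkOut(5, "grow-pool for GlobalWrite") for _ in range(2)]
    for t in tl:
        pool.checkIn(t)
    print(pool.size())   # 10: the enlarged footprint survives the check-in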
if shrinkDb: - print("numVgprAvailable=", numVgprAvailable, "minElements=", minElements, "minNeeded=", minNeeded) - if numVgprAvailable < minNeeded: - gwvwOrig = gwvw - currentOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ - self.vgprPool.size(), self.agprPool.size(), self.doubleVgpr) - futureOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ - self.vgprPool.size() - numVgprAvailable + minNeeded, self.agprPool.size(), self.doubleVgpr) + print("WARNING: half requires at least two elements per batch") + self.overflowedResources = 3 + #elif kernel["ProblemType"]["DataType"].is8bitFloat(): + # if numElementsPerBatch > 1: + # numElementsPerBatch = int(numElementsPerBatch/4)*4 + + assert numElementsPerBatch > 0, "numElementsPerBatch=0 for %s"%self.kernelName + + #numElementsPerBatch=min(2,numElementsPerBatch) # hack to control number of batches + if atomic and (self.ss.optSingleColVgpr or self.ss.optSharedColVgpr): + # hack to avoid re-using address vgpr across rows + # atomics need to perform several memory operations + # if the batch spans multiple rows, need multiple address vgpr + # which is not currently supported in the two opt*ColVgpr modes + firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] + numElementsPerBatch=min(len(firstRow),numElementsPerBatch) + + numBatches = max(1, ceil_divide(len(elements[edgeI]),numElementsPerBatch)) + + numSgprs = self.ss.cfg.numTempSgprPerBatch + self.ss.cfg.numMaskSgprPerBatch + self.ss.cfg.numMaskSgprPerElement * numElementsPerBatch + + if self.db["PrintStoreRegisterDb"]: + print("edgeI", edgeI, "NumBatches", numBatches, "NumElementsPerBatch", numElementsPerBatch, "numVgprsPerElement", numVgprsPerElement, "len(elements[edgeI])", len(elements[edgeI])) + print ("numSgprs=", numSgprs, "sgprPool.size()=", self.sgprPool.size(), "numTempSgprPerBatch=", self.ss.cfg.numTempSgprPerBatch, + "numMaskSgprPerBatch=", self.ss.cfg.numMaskSgprPerBatch, "numMaskSgprPerElement=", self.ss.cfg.numMaskSgprPerElement) + print(self.sgprPool.state()) + kStr += self.comment("edge=%d, allocate %u sgpr. perBatchTmpS=%u perBatchMaskS=%u perElementMaskS=%u elementsPerBatch=%u" % + (edgeI, numSgprs, self.ss.cfg.numTempSgprPerBatch, self.ss.cfg.numMaskSgprPerBatch, self.ss.cfg.numMaskSgprPerElement, numElementsPerBatch)) + #kStr += "// storeStats, %d, %d, %d\n"% (edgeI, numSgprs, numElementsPerBatch) + # so if we don't have *GPR resources to handle a larger batch then need + # to mark overflowedResources rather than generate a kernel that won't work. 
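The SGPR request for each batch, numSgprs = numTempSgprPerBatch + numMaskSgprPerBatch + numMaskSgprPerElement * numElementsPerBatch, is a fixed batch overhead plus a per-element mask cost. A worked example with assumed counts (the real values come from the StoreState config):

    def batch_sgprs(num_temp_per_batch, num_mask_per_batch,
                    num_mask_per_element, elements_per_batch):
        # temp sgprs + batch-wide mask sgprs + one mask allocation per element
        return (num_temp_per_batch
                + num_mask_per_batch
                + num_mask_per_element * elements_per_batch)

    # assumed: 6 temp, 2 batch-wide mask, 2 mask sgprs per element, 8 elements per batch
    print(batch_sgprs(6, 2, 2, 8))   # 24 sgprs, requested via getTmpSgpr(numSgprs, 2)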
+ tmpSgpr = self.getTmpSgpr(numSgprs, 2).idx() + + elementSgprs = tmpSgpr + self.ss.cfg.numTempSgprPerBatch + + codeAccVgprRead = deepcopy(self.codeAccVgprRead) if self.serializedStore else None + codeMulAlpha = deepcopy(self.codeMulAlpha) if self.serializedStore else None + + self.alphaBeforeLoadC = False + useCodeMulAlpha = kernel["MIArchVgpr"] and alpha and not (kernel["GlobalSplitU"] > 1) + if useCodeMulAlpha: # do not set codeAccVgprRead=None if GSU>1 + codeAccVgprRead = None + + #Only apply when 2 wave optimization features are enabled + if (kernel["StorePriorityOpt"] or kernel["StoreSyncOpt"]) and beta: + self.alphaBeforeLoadC = True + else: + codeMulAlpha = None + + for batchIdx in range(0, numBatches): + elementStartIdx = batchIdx * numElementsPerBatch + elementStopIdx = min( elementStartIdx + numElementsPerBatch, len(elements[edgeI]) ) + elementsThisBatch = elements[edgeI][elementStartIdx:elementStopIdx] + #print("BATCH[%u/%u]: elements[edgeI][%u:%u] VGPRs=%u" % (batchIdx, numBatches, elementStartIdx, elementStopIdx,numVgprsPerElement )) + # elementVgprs can be large and should be perfectly tuned to the number of available + # VGPRS. We do not want to accidentally overflow and grow the pool here: + + kStr += self.partialsWriteBatch(kernel, self.ss, batchIdx, alpha, beta, edge, atomic, gwvw, atomicW, \ + elementsThisBatch, self.coord0, self.coord1, self.addrD, self.addrC, \ + tmpVgpr, tmpCVTVgpr, \ + elementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, isOptNLL) + # delay PreLoopVmcntCase code after globalWrite + if self.canOptimizePreLoopLWVmcnt: + kStr += PreLoopVmcntCaseStr + + # Set flag + kStr += inst("s_waitcnt", "vmcnt(0)", "wait for data store") + kStr += inst("s_barrier", "store all data before setting flag") + kStr += inst("s_lshl_b32", sgpr(tmpSgpr), sgpr("StreamKIdx"), log2(4), "flag offset based on CTA index") + kStr += inst("s_mov_b32", sgpr(tmpSgpr+2), 1, "flag data") + kStr += inst("s_store_dword", sgpr(tmpSgpr+2), sgpr("AddressFlags", 2), sgpr(tmpSgpr), "glc", "set flag") + kStr += inst("s_waitcnt", "lgkmcnt(0)", "wait for flag") # TODO just for testing + + # TODO - if this is the last tile, don't need to jump to next instruction + # NOTE: in SR kernel, we need long branch since PRNG explodes the line of codes + if kernel["ProblemType"]["StochasticRounding"]: # in-device RND + endLabelName = "label_%s"%(endLabel) + kStr += self.longBranchPositive(endLabelName) + else: + kStr += inst("s_branch", "label_%s"%endLabel, "jump to end") + del self.ss + + # Finish one write path, reset currPreLoopVmcntCase to Undefined + self.currPreLoopVmcntCase = PreLoopVmcntCase.Undefined + + return kStr + + + ############################################################################## + # Partials Write Batch + ############################################################################## + def partialsWriteBatch(self, kernel, ss, batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \ + batchElements, coord0, coord1, addrD, addrC, \ + tmpVgpr, tmpCVTVgpr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, isOptNLL): + kStr = "" + + kStr += self.comment1("optSingleColVgpr=%u optSharedColVgpr=%u optSGPRUsage=%s optSrdIncForRow=%u" % \ + (ss.optSingleColVgpr, ss.optSharedColVgpr, ss.optSGPRUsage, ss.optSrdIncForRow)) + + if kernel["StoreSyncOpt"]: + kStr += "s_sleep %d // optimization: sync and wait\n" %(kernel["StoreSyncOpt"]-1) + kStr += "s_barrier\n" + + # comment tt1, tt0, vc1, vc0 + # tt = thread tile, vc=vector component + commentStr = "Partials Write%s%s%s Batch #%u 
(d1,d0,vc1,vc0) =\n " \ + % (" Alpha" if applyAlpha else "", " Beta" if beta else "", " Edge" if edge else "", batchIdx) + for elementIdx in range(0, len(batchElements)): + element = batchElements[elementIdx] + commentStr += "(%u,%u,%u,%u:vw%u%s)" % \ + (element[0], element[1], element[2], element[3], gwvw, + ":vaw:%u"%atomicW if atomic else "") + if elementIdx < len(batchElements)-1: + commentStr += "; " + kStr += self.comment3(commentStr) - if shrinkDb: - print("currentOccupancy=%u futureOccupancy=%u VGPRs=%u numVgprAvail=%u vgprPerElem=%u" \ - % (currentOccupancy, futureOccupancy, self.vgprPool.size(), \ - numVgprAvailable, minElements*numVgprsPerElement)) - if futureOccupancy > currentOccupancy: - if shrinkDb: - print("warning: %s growing VGPR for GlobalWrite batching - this may bloat VGPR usage" % \ - (self.kernelName)) - print(" numVgprAvailable=", numVgprAvailable, \ - "numVgprsPerElement=", numVgprsPerElement, "atomic=", atomic, \ - "beta=", beta, "gwvw=", gwvw) - elif gwvw != gwvwOrig: - self.ss.gwvw = gwvw # make both representations consistent - if shrinkDb: - print2("info: %s shrank gwvw from %u to %u but kept occupancy same=%u." \ - % (self.kernelName, gwvwOrig, gwvw, currentOccupancy)) - - if numVgprAvailable < minElements*numVgprsPerElement: - print2("info: growing pool += %d * %d for GlobalWrite\n" \ - % (minElements,numVgprsPerElement)) - print2(self.vgprPool.state()) - tl = [] - for i in range(0,minElements): - tl.append(self.vgprPool.checkOut(numVgprsPerElement, "grow-pool for GlobalWrite")) - for t in tl: - self.vgprPool.checkIn(t) - numVgprAvailable = self.vgprPool.available() - print2(self.vgprPool.state()) - - # set atomicW after we potentially resize GWVW - atomicW = min(gwvw, kernel["VectorAtomicWidth"]) - - # print("NumVgprAvailable", numVgprAvailable) - if numVgprsPerElement: - numElementsPerBatch = numVgprAvailable // numVgprsPerElement - else: - numElementsPerBatch = len(elements[edgeI]) # max, do 'em all + # allow expanding vgpr pool for OptNLL + preventOverflow = (not isOptNLL) + ss.setupStoreElementsForBatch(kernel, gwvw, batchElements, batchElementSgprs, preventOverflow=preventOverflow, \ + VectorWidthB=self.VectorWidthB, isWorkspace=True) + + storesIssued = 0 + tmpS01 = tmpSgpr # scratch sgprs + + ######################################## + # calculate addr and masks + kStr += self.comment("calc coords, apply mask, and issue loads (if necessary)") + # On input, coord0 and coord1 are VGPRs computed in the pre-batch code, based + # on the thread and tid number. These are ELEMENT offsets from start of tensor C + # for the top-left corner this thread will write. These are not changed + # across all the store loop iters. 
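For reference, the completion handshake emitted at the end of partialsWriteProcedure above writes a one-dword flag per Stream-K CTA: the s_lshl_b32 by log2(4) turns StreamKIdx into a byte offset into the AddressFlags buffer, and the dword is stored with glc so it is visible to the CTA that later consumes the partials. A standalone sketch of just the offset arithmetic (the buffer size and the polling consumer are assumptions):

    FLAG_BYTES = 4   # one dword flag per CTA; hence the log2(4) shift in the generated asm

    def flag_byte_offset(stream_k_idx):
        # s_lshl_b32 tmp, StreamKIdx, log2(4)   ->   tmp = StreamKIdx * 4
        return stream_k_idx << 2

    # hypothetical 8-CTA flag buffer; the fixup path is assumed to poll for a non-zero dword
    flags = bytearray(FLAG_BYTES * 8)
    cta = 5
    flags[flag_byte_offset(cta):flag_byte_offset(cta) + FLAG_BYTES] = (1).to_bytes(4, "little")
    print(flag_byte_offset(cta))   # 20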
+ if self.db["ConservativeWaitCnt"] & 0x10: + kStr += "s_barrier // debug\n" + kStr += inst("s_waitcnt", "vmcnt(0)", "ConservativeWaitCnt" ) + if self.archCaps["SeparateVscnt"]: + kStr += inst("s_waitcnt_vscnt", "null", "0", "writes") + kStr += "s_barrier // debug\n" + if not edge and self.db["ForceEdgeStores"]>=2: + kStr += self.bomb() # should not get here + if edge and self.db["AssertNoEdge"]: + kStr += self.bomb() # should not get here + + ## create code Module to push mov vgpr,acc instructions + if kernel["StoreCInUnroll"] and not edge: + accVgprRead = Code.Module("movaccVgpr") + self.StoreCUnrollLoadCWaitComment = "waitcnt for LoadC" # this will be used later to identify waitcnt for loadC + + ######################################## + # AccVgpr read + if kernel.enabledSetPrioSplitLDS: + kStr += inst("s_setprio", "0", "") + if codeAccVgprRead is not None: + regsPerScalar = self.bpeCinternal//self.bpr # register per scalar + # loop over store instructions within one batch + for elementIdx in range(0, len(batchElements)): + # loop over scalars within one store instruction + for vi in range(0, gwvw): + # loop over registers within one scalar + for rIdx in range(0, regsPerScalar): + tempStr = str(codeAccVgprRead.items().pop(0)) + kStr += tempStr.replace("__placeholder__", str(ss.elementSumIdx[elementIdx]*regsPerScalar + regsPerScalar*vi + rIdx)) + if kernel["StoreCInUnroll"] and not edge: + tempStr = tempStr.replace("__placeholder__",str(elementIdx*gwvw*regsPerScalar + regsPerScalar*vi + rIdx)) + accVgprRead.addCode(tempStr.replace("ValuC","L2GC")) + + if not kernel["MIArchVgpr"]: + kStr += inst("s_nop 1", "2 wait states required before reading vgpr") + + ######################################## + # Not Atomic + ######################################## + # else: + # edge has v_cndmask so loads or stores may not issue, hard to track vmcnt: + for elementIdx in range(0, len(batchElements)): + for vi in range(0, gwvw): + sumIdxV = ss.elementSumIdx[elementIdx] + vi + # covers sgemm, gemm_ex(HHS/HSS/BBS/BSS (HPA=T)), int8 (int8x4?) 
+ if kernel["ProblemType"]["ComputeDataType"].isInt32() or \ + kernel["ProblemType"]["ComputeDataType"].isSingle(): # covers sgemm/gemm_ex(HHS/HSS/BBS/BSS) + if self.db["ForceExpectedValue"]: + kStr += inst("v_mov_b32", vgpr("ValuC+%u"%sumIdxV), self.db["ValueCExpectedValue"], "force expected value" ) + if self.db["ForceVSerial"]: + kStr += inst("v_mov_b32", vgpr("ValuC+%u"%sumIdxV), vgpr("Serial"), "force expected value to serial" ) + if self.db["CheckValueC"]: + kStr += inst("s_mov_b32", sgpr(tmpS01), self.db["ValueCExpectedValue"], "Move expected value") + kStr += self.assert_eq(vgpr("ValuC+%u"%sumIdxV), sgpr(tmpS01)) + + kStr += self.comment("apply mask, calc new C and issue writes") + #kStr += self.bomb() # can see store addresses just before the store inst + + if kernel["ProblemType"]["DestDataType"].isBFloat16() and kernel["ProblemType"]["HighPrecisionAccumulate"]: + vgprBf16Temp = tmpCVTVgpr + vgprBf16Mask = vgprBf16Temp + 1 + vgprFp32Nan = vgprBf16Temp + 2 + vgprBf16Inc = vgprBf16Temp + 3 + kStr += inst("v_mov_b32", vgpr(vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ) + kStr += inst("v_mov_b32", vgpr(vgprFp32Nan), "0x7fff0000", "fp32 Nan" ) + kStr += inst("v_mov_b32", vgpr(vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ) + + # DestDataType for 8bit Float can only be F8 or B8 + if kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8(): # F8 is always HPA + # make vgprF8Temp0 always even to use pk instruction later + if tmpCVTVgpr % 2 == 0: + vgprF8Temp0 = tmpCVTVgpr + vgprF8Max = vgprF8Temp0 + 2 + vgprF8Min = vgprF8Temp0 + 3 + else: + vgprF8Max = tmpCVTVgpr + vgprF8Temp0 = vgprF8Max + 1 + vgprF8Min = vgprF8Max + 3 + + if kernel["ProblemType"]["Fp32toFp8SWClip"]: + # set flag of f32 NaN and +/- INF for v_cmp_class + vgprFp32NanInfFlag = vgprF8Min + 1 + kStr += inst("v_mov_b32", vgpr(vgprFp32NanInfFlag), "0x207", "flag for Nan and +/- inf" ) + # set max/min values for clipping + if kernel["ProblemType"]["DestDataType"].isFloat8(): + kStr += inst("v_mov_b32", vgpr(vgprF8Max), "0x43700000", "save 240.0f as max for clipping" ) + kStr += inst("v_mov_b32", vgpr(vgprF8Min), "0xC3700000", "save -240.0f as min for clipping" ) + else: #BFloat8 + kStr += inst("v_mov_b32", vgpr(vgprF8Max), "0x47600000", "save 57344.0f as max for clipping" ) + kStr += inst("v_mov_b32", vgpr(vgprF8Min), "0xC7600000", "save -57344`.0f as min for clipping" ) + + storeCode = "" + for elementIdx in range(0, len(batchElements)): + element = batchElements[elementIdx] + addr = ss.elementAddr[elementIdx].addrDVgpr + addrCalc = ss.elementAddr[elementIdx] + sumIdx = ss.elementSumIdx[elementIdx] + + storeWidth = kernel["StoreVectorWidth"] + # storeWidth = 2 + if batchIdx == 0 and elementIdx == 0: + kStr += staticMultiply(vgpr(addr), vgpr("Serial"), storeWidth * self.bpeCinternal, sgpr(tmpS01)) + # kStr += inst("v_mul_lo_u32", , "Partials buffer address") + kStr += inst("s_mov_b32", sgpr(tmpS01), 0, "Init sgpr offset") + else: + increment = (kernel["WavefrontSize"] * 4) * storeWidth * self.bpeCinternal + kStr += inst("s_add_u32", sgpr(tmpS01), sgpr(tmpS01), increment, "Inc sgpr offset") + + # TODO StreamK need this packing code??? 
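The workspace addressing above gives every lane a base byte offset of Serial * storeWidth * bpeCinternal, then bumps a scalar offset once per stored element by a full workgroup's worth of data, (WavefrontSize * 4) * storeWidth * bpeCinternal, where the factor of 4 presumably reflects the four waves of a 256-thread workgroup. A standalone sketch with assumed sizes, counting elements across all batches written by this CTA:

    def partials_offsets(serial, element_ordinal, store_width, bpe_c_internal,
                         wavefront_size=64, waves_per_wg=4):
        # per-lane base, written into the address VGPR for batch 0 / element 0
        lane_base = serial * store_width * bpe_c_internal
        # scalar offset after element_ordinal increments: one whole-workgroup stride per element
        sgpr_offset = (element_ordinal * (wavefront_size * waves_per_wg)
                       * store_width * bpe_c_internal)
        return lane_base, sgpr_offset

    # assumed: lane 70, third element written by this CTA, StoreVectorWidth=4, 4-byte internal type
    print(partials_offsets(70, 2, 4, 4))   # (1120, 8192)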
+ # if self.asmCaps["HasWMMA"] and kernel["EnableMatrixInstructionStore"] and kernel["ProblemType"]["DestDataType"].isHalf() and (not kernel["ProblemType"]["HighPrecisionAccumulate"]): + # for vi in range(0, gwvw): + # sumIdxV = ss.elementSumIdx[elementIdx] + vi + # if vi%2 == 1: + # d = ss.elementSumIdx[elementIdx] + vi//2 + # kStr += inst("v_pack_b32_f16", vgpr(d), vgpr("ValuC+%u"%(sumIdxV-1)), vgpr("ValuC+%u"%sumIdxV), "Pack with neighbor" ) + + # if not kernel["StoreRemapVectorWidth"]: + tmpStoreCode = self.addStore(kernel, ss, addrCalc, sumIdx, tmpS01, edge, 'WS', sgpr(tmpS01)) + if kernel["GroupLoadStore"]: + storeCode += tmpStoreCode + else: + kStr += tmpStoreCode + storesIssued += 1 + + kStr += storeCode + + # return registers to pool: + lastData = -1 + for elementIdx in range(0, len(batchElements)): + if not ss.sharedColDVgprs: + addrDVgpr = ss.elementAddr[elementIdx].addrDVgpr + addrCVgpr = ss.elementAddr[elementIdx].addrCVgpr + self.vgprPool.checkIn(addrDVgpr) + if addrCVgpr != addrDVgpr: + self.vgprPool.checkIn(addrCVgpr) + + data = ss.elementData[elementIdx] + if data != 0: + if data != lastData: + self.vgprPool.checkIn(data) + lastData = data + + self.ss.firstBatch = False + self.ss.checkInTempVgprC() + + if self.serializedStore: + kStr += inst("s_nop 0", "1 wait state required when next inst writes vgprs held by previous dwordx4 store inst") - assert(self.numVgprValuC % gwvw == 0) # sanity check + # Update the store cnt to preLoopVmcntDict for Case2/3 + # (No need to update for Case0:'Undefined' or Case4:'OrdNLL_B1_Store') + if self.currPreLoopVmcntCase in self.preLoopVmcntDict: + if not self.archCaps["SeparateVscnt"]: + self.preLoopVmcntDict[self.currPreLoopVmcntCase] += storesIssued + + return kStr + + ############################################################################## + # Global Write Procedure + ############################################################################## + def globalWriteProcedure(self, kernel, vectorWidths, elements, alpha, beta, edge, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel): + kStr = "" + + PreLoopVmcntCaseStr = "" + # not generate Case 2 if StoreCInUnroll with StoreVectorWidth==1 (Case 2 will be same as Case 3) + if self.canOptimizePreLoopLWVmcnt: + if beta: + self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_B1_Store + elif edge or (kernel["StoreCInUnroll"] and kernel["StoreVectorWidth"]==1): + self.currPreLoopVmcntCase = PreLoopVmcntCase.OrdNLL_E1_Store + else: + self.currPreLoopVmcntCase = PreLoopVmcntCase.OptNLL_Store + PreLoopVmcntCaseStr = inst("s_mov_b32", sgpr("PreLoopLWVmcntCase"), hex(self.currPreLoopVmcntCase.value), \ + "for optimizing next PreLoop LW vmcnt, set to Case%u"%self.currPreLoopVmcntCase.value) + # reset vmcnt if the dict has this key (OptNLL_Store, OrdNLL_E1_Store), + # OrdNLL_B1_Store is excluded + if self.currPreLoopVmcntCase in self.preLoopVmcntDict: + self.preLoopVmcntDict[self.currPreLoopVmcntCase] = 0 + + # for storeRemap edge case, non-beta still can enable vector stores + if kernel["StoreRemapVectorWidth"] and not beta: + edgeI = False + else: + edgeI = edge + #edgeI = True # set to True to disable vector stores + # print(edgeI) + # print(vectorWidths) + gwvw = vectorWidths[edgeI] + #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic + + ######################################## + # Calculate Vgprs for Write Batching + ######################################## + + self.ss = self.StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI]) + + # how many vgprs are 
needed for zero elements + # 2 for addressC in vgpr for addition - already checked out + # 2 for coord0,1 of thread - already checked out + # 2 for tmp - already checked out + + # 5 = how many vgprs are needed per element (flat) + # - 2 for addr + # - 3 for GLOBAL_OFFSET_C calculation (can overlap below, therefore max) + # - if beta gwvw*rpe for new value + # - if atomic 2*rpe for old and cmp values + + # print("numVgprsPerAddr=%u, numVgprsPerDataPerVI=%u, numVgprPerValuC=%u"%(self.ss.cfg.numVgprsPerAddr, self.ss.cfg.numVgprsPerDataPerVI, self.ss.cfg.numVgprPerValuC)) + numVgprsPerElement = self.ss.cfg.numVgprPerValuC*gwvw + self.ss.cfg.numVgprsPerAddr + int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) - numElementsPerBatch = numElementsPerBatch if not kernel["NumElementsPerBatchStore"] else min(kernel["NumElementsPerBatchStore"],numElementsPerBatch) + if kernel["GroupLoadStore"] and kernel["ProblemType"]["UseBeta"]: + numVgprsPerElement += self.ss.cfg.numVgprsPerAddr + #print self.vgprPool.state() + # Use VGPR up to next occupancy threshold: + maxVgprs = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), \ + self.getLdsSize(kernel), self.agprPool.size(), self.doubleVgpr) + if self.serializedStore: # get aggressive when serializedStore is on; not necessarily exclusive to this parameter + len(elements[edgeI]) + tl = [] + for i in range(self.vgprPool.size()-self.vgprPool.available(), maxVgprs): + tl.append(self.vgprPool.checkOut(1, "grow-pool up to next occupancy for GlobalWrite")) + for t in tl: + self.vgprPool.checkIn(t) + align = 1 + # align adjustment + if self.ss.cfg.numVgprsPerAddr > 1: + align = max(align, self.ss.cfg.numVgprsPerAddr) + if self.ss.cfg.numVgprPerValuC*gwvw > 1: + align = max(align, self.ss.cfg.numVgprPerValuC*gwvw) + if int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw)) > 1: + align = max(align, int(ceil(self.ss.cfg.numVgprsPerDataPerVI * gwvw))) + numVgprAvailable = self.vgprPool.availableBlock(numVgprsPerElement, align) + + # Grow the register pool if needed - we need enough regs for at least one element + # Unfortunate since this means the write logic is setting the VGPR requirement + # for the entire kernel but at least we have a functional kernel. + # Before growing the pool, see if we can shrink the write vector width instead? + # TODO : the vgprSerial is needed for-ever and if we grow here will split the + # range of the tmps. Maybe want to move vgprSerial to first vgpr? + + # TODO: Minimum elems for StoreRemap + # TODO: Which of DataType or DestDataType is in a better sense? 
0114: Check Using DestDataType + HSS + minElements = 1 + if kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16(): + minElements = 2 + elif kernel["ProblemType"]["DataType"].is8bitFloat(): + minElements = 4 + + minNeeded = minElements * numVgprsPerElement + shrinkDb = 0 + if shrinkDb: + print("numVgprAvailable=", numVgprAvailable, "minElements=", minElements, "minNeeded=", minNeeded) + if numVgprAvailable < minNeeded: + gwvwOrig = gwvw + currentOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size(), self.agprPool.size(), self.doubleVgpr) + futureOccupancy = self.getOccupancy(kernel["NumThreads"], self.getLdsSize(kernel), \ + self.vgprPool.size() - numVgprAvailable + minNeeded, self.agprPool.size(), self.doubleVgpr) + + if shrinkDb: + print("currentOccupancy=%u futureOccupancy=%u VGPRs=%u numVgprAvail=%u vgprPerElem=%u" \ + % (currentOccupancy, futureOccupancy, self.vgprPool.size(), \ + numVgprAvailable, minElements*numVgprsPerElement)) + if futureOccupancy > currentOccupancy: if shrinkDb: - print("NumElementsPerBatch=", numElementsPerBatch, "LimitedBySgprs=", self.ss.cfg.numElementsPerBatchLimitedBySgprs, \ - "WARNING" if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch else "okay") - if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch: - numElementsPerBatch = self.ss.cfg.numElementsPerBatchLimitedBySgprs - - # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS - if (kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()): - # only do an even number of halves - since these share hi/lo pieces of some registers? - if numElementsPerBatch > 1: - numElementsPerBatch = int(numElementsPerBatch/2)*2 - elif not kernel["EnableMatrixInstructionStore"]: - # The globalWriteBatch routine below can't handle odd elements per batch - # and 0 elements per batch is illegal. - # so if we don't have *GPR resources to handle a larger batch then need - # to mark overflowedResources rather than generate a kernel that won't work. - # It might be possible to fix globalWriteBatch to handle this case but these - # are likely to be low-performing so likely not worth optimizing. 
- if shrinkDb: - print("WARNING: half requires at least two elements per batch") - self.overflowedResources = 3 - #elif kernel["ProblemType"]["DataType"].is8bitFloat(): - # if numElementsPerBatch > 1: - # numElementsPerBatch = int(numElementsPerBatch/4)*4 - - assert numElementsPerBatch > 0, "numElementsPerBatch=0 for %s"%self.kernelName - - #numElementsPerBatch=min(2,numElementsPerBatch) # hack to control number of batches - if atomic and (self.ss.optSingleColVgpr or self.ss.optSharedColVgpr): - # hack to avoid re-using address vgpr across rows - # atomics need to perform several memory operations - # if the batch spans multiple rows, need multiple address vgpr - # which is not currently supported in the two opt*ColVgpr modes - firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] - numElementsPerBatch=min(len(firstRow),numElementsPerBatch) - - # check best numElementsPerBatch to handle a column block - # elements of column block must be multiple size of numElementsPerBatch - if kernel["StoreRemapVectorWidth"]: - firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0) - # find the largest factor and smaller than numElementPerBatch - nBatchesPerRow = 1 - for d in range(1, len(firstRow)+1): - largestFactor = len(firstRow)//d - if len(firstRow)%d == 0 and largestFactor <= numElementsPerBatch: - numElementsPerBatch = largestFactor - nBatchesPerRow = d - break - - # if no atomics and no edge, then write whole vectors - #if not atomic and not edge: - # numVectorsPerBatch = numElementsPerBatch / kernel["GlobalWriteVectorWidth"] - # #print " NumVectorsPerBatch", numVectorsPerBatch - # numElementsPerBatch = numVectorsPerBatch * kernel["GlobalWriteVectorWidth"] - numBatches = max(1, ceil_divide(len(elements[edgeI]),numElementsPerBatch)) - - numSgprs = self.ss.cfg.numTempSgprPerBatch + self.ss.cfg.numMaskSgprPerBatch + self.ss.cfg.numMaskSgprPerElement * numElementsPerBatch - - if self.db["PrintStoreRegisterDb"]: - print("edgeI", edgeI, "NumBatches", numBatches, "NumElementsPerBatch", numElementsPerBatch, "numVgprsPerElement", numVgprsPerElement, "len(elements[edgeI])", len(elements[edgeI])) - print ("numSgprs=", numSgprs, "sgprPool.size()=", self.sgprPool.size(), "numTempSgprPerBatch=", self.ss.cfg.numTempSgprPerBatch, - "numMaskSgprPerBatch=", self.ss.cfg.numMaskSgprPerBatch, "numMaskSgprPerElement=", self.ss.cfg.numMaskSgprPerElement) - print(self.sgprPool.state()) - kStr += self.comment("edge=%d, allocate %u sgpr. perBatchTmpS=%u perBatchMaskS=%u perElementMaskS=%u elementsPerBatch=%u" % - (edgeI, numSgprs, self.ss.cfg.numTempSgprPerBatch, self.ss.cfg.numMaskSgprPerBatch, self.ss.cfg.numMaskSgprPerElement, numElementsPerBatch)) - #kStr += "// storeStats, %d, %d, %d\n"% (edgeI, numSgprs, numElementsPerBatch) + print("warning: %s growing VGPR for GlobalWrite batching - this may bloat VGPR usage" % \ + (self.kernelName)) + print(" numVgprAvailable=", numVgprAvailable, \ + "numVgprsPerElement=", numVgprsPerElement, "atomic=", atomic, \ + "beta=", beta, "gwvw=", gwvw) + elif gwvw != gwvwOrig: + self.ss.gwvw = gwvw # make both representations consistent + if shrinkDb: + print2("info: %s shrank gwvw from %u to %u but kept occupancy same=%u." 
\ + % (self.kernelName, gwvwOrig, gwvw, currentOccupancy)) + + if numVgprAvailable < minElements*numVgprsPerElement: + print2("info: growing pool += %d * %d for GlobalWrite\n" \ + % (minElements,numVgprsPerElement)) + print2(self.vgprPool.state()) + tl = [] + for i in range(0,minElements): + tl.append(self.vgprPool.checkOut(numVgprsPerElement, "grow-pool for GlobalWrite")) + for t in tl: + self.vgprPool.checkIn(t) + numVgprAvailable = self.vgprPool.available() + print2(self.vgprPool.state()) + + # set atomicW after we potentially resize GWVW + atomicW = min(gwvw, kernel["VectorAtomicWidth"]) + + # print("NumVgprAvailable", numVgprAvailable) + if numVgprsPerElement: + numElementsPerBatch = numVgprAvailable // numVgprsPerElement + else: + numElementsPerBatch = len(elements[edgeI]) # max, do 'em all + + assert(self.numVgprValuC % gwvw == 0) # sanity check + + numElementsPerBatch = numElementsPerBatch if not kernel["NumElementsPerBatchStore"] else min(kernel["NumElementsPerBatchStore"],numElementsPerBatch) + + if shrinkDb: + print("NumElementsPerBatch=", numElementsPerBatch, "LimitedBySgprs=", self.ss.cfg.numElementsPerBatchLimitedBySgprs, \ + "WARNING" if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch else "okay") + if self.ss.cfg.numElementsPerBatchLimitedBySgprs < numElementsPerBatch: + numElementsPerBatch = self.ss.cfg.numElementsPerBatchLimitedBySgprs + + # TODO: Which of DataType or DestDataType is in a better sense? 0114: Check Using DestDataType + HSS + if (kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()): + # only do an even number of halves - since these share hi/lo pieces of some registers? + if numElementsPerBatch > 1: + numElementsPerBatch = int(numElementsPerBatch/2)*2 + elif not kernel["EnableMatrixInstruction"]: + # (excluding MFMA+LSU case. It can work without an issue) + # The globalWriteBatch routine below can't handle odd elements per batch + # and 0 elements per batch is illegal. # so if we don't have *GPR resources to handle a larger batch then need # to mark overflowedResources rather than generate a kernel that won't work. - tmpSgpr = self.getTmpSgpr(numSgprs, 2).idx() + # It might be possible to fix globalWriteBatch to handle this case but these + # are likely to be low-performing so likely not worth optimizing. 
+ if shrinkDb: + print("WARNING: half requires at least two elements per batch") + self.overflowedResources = 3 + #elif kernel["ProblemType"]["DataType"].is8bitFloat(): + # if numElementsPerBatch > 1: + # numElementsPerBatch = int(numElementsPerBatch/4)*4 + + assert numElementsPerBatch > 0, "numElementsPerBatch=0 for %s"%self.kernelName + + #numElementsPerBatch=min(2,numElementsPerBatch) # hack to control number of batches + if atomic and (self.ss.optSingleColVgpr or self.ss.optSharedColVgpr): + # hack to avoid re-using address vgpr across rows + # atomics need to perform several memory operations + # if the batch spans multiple rows, need multiple address vgpr + # which is not currently supported in the two opt*ColVgpr modes + firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] + numElementsPerBatch=min(len(firstRow),numElementsPerBatch) + + # check best numElementsPerBatch to handle a column block + # elements of column block must be multiple size of numElementsPerBatch + if kernel["StoreRemapVectorWidth"]: + firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0) + # find the largest factor and smaller than numElementPerBatch + nBatchesPerRow = 1 + for d in range(1, len(firstRow)+1): + largestFactor = len(firstRow)//d + if len(firstRow)%d == 0 and largestFactor <= numElementsPerBatch: + numElementsPerBatch = largestFactor + nBatchesPerRow = d + break + + # if no atomics and no edge, then write whole vectors + #if not atomic and not edge: + # numVectorsPerBatch = numElementsPerBatch / kernel["GlobalWriteVectorWidth"] + # #print " NumVectorsPerBatch", numVectorsPerBatch + # numElementsPerBatch = numVectorsPerBatch * kernel["GlobalWriteVectorWidth"] + numBatches = max(1, ceil_divide(len(elements[edgeI]),numElementsPerBatch)) + + numSgprs = self.ss.cfg.numTempSgprPerBatch + self.ss.cfg.numMaskSgprPerBatch + self.ss.cfg.numMaskSgprPerElement * numElementsPerBatch + + if self.db["PrintStoreRegisterDb"]: + print("edgeI", edgeI, "NumBatches", numBatches, "NumElementsPerBatch", numElementsPerBatch, "numVgprsPerElement", numVgprsPerElement, "len(elements[edgeI])", len(elements[edgeI])) + print ("numSgprs=", numSgprs, "sgprPool.size()=", self.sgprPool.size(), "numTempSgprPerBatch=", self.ss.cfg.numTempSgprPerBatch, + "numMaskSgprPerBatch=", self.ss.cfg.numMaskSgprPerBatch, "numMaskSgprPerElement=", self.ss.cfg.numMaskSgprPerElement) + print(self.sgprPool.state()) + kStr += self.comment("edge=%d, allocate %u sgpr. perBatchTmpS=%u perBatchMaskS=%u perElementMaskS=%u elementsPerBatch=%u" % + (edgeI, numSgprs, self.ss.cfg.numTempSgprPerBatch, self.ss.cfg.numMaskSgprPerBatch, self.ss.cfg.numMaskSgprPerElement, numElementsPerBatch)) + #kStr += "// storeStats, %d, %d, %d\n"% (edgeI, numSgprs, numElementsPerBatch) + # so if we don't have *GPR resources to handle a larger batch then need + # to mark overflowedResources rather than generate a kernel that won't work. 
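The StoreRemapVectorWidth adjustment above shrinks numElementsPerBatch to the largest divisor of the first-row element count that still fits the VGPR-derived limit, so each column block is covered by a whole number of batches (nBatchesPerRow). A standalone sketch of the same search (element tuples are (tt1, tt0, vc1, vc0), as in the comment):

    def remap_batch_size(first_row_len, num_elements_per_batch):
        # walk divisor counts upward; the first hit gives the largest per-batch factor
        for d in range(1, first_row_len + 1):
            largest_factor = first_row_len // d
            if first_row_len % d == 0 and largest_factor <= num_elements_per_batch:
                return largest_factor, d        # (numElementsPerBatch, nBatchesPerRow)
        return 1, first_row_len

    # e.g. 12 elements in the first row but only 8 fit per batch -> 6 elements, 2 batches per row
    print(remap_batch_size(12, 8))   # (6, 2)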
+ tmpSgpr = self.getTmpSgpr(numSgprs, 2).idx() + + elementSgprs = tmpSgpr + self.ss.cfg.numTempSgprPerBatch + + codeAccVgprRead = deepcopy(self.codeAccVgprRead) if self.serializedStore else None + codeMulAlpha = deepcopy(self.codeMulAlpha) if self.serializedStore else None + + self.alphaBeforeLoadC = False + useCodeMulAlpha = kernel["MIArchVgpr"] and alpha and not (kernel["GlobalSplitU"] > 1) + if useCodeMulAlpha: # do not set codeAccVgprRead=None if GSU>1 + codeAccVgprRead = None + + #Only apply when 2 wave optimization features are enabled + if (kernel["StorePriorityOpt"] or kernel["StoreSyncOpt"]) and beta: + self.alphaBeforeLoadC = True + else: + codeMulAlpha = None + + for batchIdx in range(0, numBatches): + elementStartIdx = batchIdx * numElementsPerBatch + elementStopIdx = min( elementStartIdx + numElementsPerBatch, len(elements[edgeI]) ) + elementsThisBatch = elements[edgeI][elementStartIdx:elementStopIdx] + #print("BATCH[%u/%u]: elements[edgeI][%u:%u] VGPRs=%u" % (batchIdx, numBatches, elementStartIdx, elementStopIdx,numVgprsPerElement )) + # elementVgprs can be large and should be perfectly tuned to the number of available + # VGPRS. We do not want to accidentally overflow and grow the pool here: - elementSgprs = tmpSgpr + self.ss.cfg.numTempSgprPerBatch + if kernel["StoreRemapVectorWidth"]: + #Indication if this batch is last batch for this column block shape + self.StoreRemapLastBatch = 1 if (batchIdx+1) % nBatchesPerRow == 0 else 0 + + kStr += self.globalWriteBatch(kernel, self.ss, batchIdx, alpha, beta, edge, atomic, gwvw, atomicW, \ + elementsThisBatch, self.coord0, self.coord1, self.addrD, self.addrC, \ + tmpVgpr, tmpCVTVgpr, \ + elementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, isOptNLL) + # delay PreLoopVmcntCase code after globalWrite + if self.canOptimizePreLoopLWVmcnt: + kStr += PreLoopVmcntCaseStr - codeAccVgprRead = deepcopy(self.codeAccVgprRead) if self.serializedStore else None - codeMulAlpha = deepcopy(self.codeMulAlpha) if self.serializedStore else None + # TODO - if this is the last tile, don't need to jump to next instruction + # NOTE: in SR kernel, we need long branch since PRNG explodes the line of codes + if kernel["ProblemType"]["StochasticRounding"]: # in-device RND + endLabelName = "label_%s"%(endLabel) + kStr += self.longBranchPositive(endLabelName) + else: + kStr += inst("s_branch", "label_%s"%endLabel, "jump to end") + del self.ss - self.alphaBeforeLoadC = False - if useCodeMulAlpha: # do not set codeAccVgprRead=None if GSU>1 - codeAccVgprRead = None + # Finish one write path, reset currPreLoopVmcntCase to Undefined + self.currPreLoopVmcntCase = PreLoopVmcntCase.Undefined - #Only apply when 2 wave optimization features are enabled - if (kernel["StorePriorityOpt"] or kernel["StoreSyncOpt"]) and beta: - self.alphaBeforeLoadC = True - else: - codeMulAlpha = None + return kStr - for batchIdx in range(0, numBatches): - elementStartIdx = batchIdx * numElementsPerBatch - elementStopIdx = min( elementStartIdx + numElementsPerBatch, len(elements[edgeI]) ) - elementsThisBatch = elements[edgeI][elementStartIdx:elementStopIdx] - #print("BATCH[%u/%u]: elements[edgeI][%u:%u] VGPRs=%u" % (batchIdx, numBatches, elementStartIdx, elementStopIdx,numVgprsPerElement )) - # elementVgprs can be large and should be perfectly tuned to the number of available - # VGPRS. 
We do not want to accidentally overflow and grow the pool here: - if kernel["StoreRemapVectorWidth"]: - #Indication if this batch is last batch for this column block shape - self.StoreRemapLastBatch = 1 if (batchIdx+1) % nBatchesPerRow == 0 else 0 - - kStr += self.globalWriteBatch(kernel, self.ss, batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \ - elementsThisBatch, self.coord0, self.coord1, self.addrD, self.addrC, \ - tmpVgpr, tmpCVTVgpr, \ - elementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, isOptNLL) - # delay PreLoopVmcntCase code after globalWrite - if self.canOptimizePreLoopLWVmcnt: - kStr += PreLoopVmcntCaseStr - - # TODO - if this is the last tile, don't need to jump to next instruction - # NOTE: in SR kernel, we need long branch since PRNG explodes the line of codes - if kernel["ProblemType"]["StochasticRounding"]: # in-device RND - endLabelName = "label_%s"%(endLabel) - kStr += self.longBranchPositive(endLabelName) - else: - kStr += inst("s_branch", "label_%s"%endLabel, "jump to end") - del self.ss + ############################################################################## + # Global Write Elements + ############################################################################## + def globalWriteElements(self, kernel, vectorWidths, elements, + applyAlpha=True, # defaults to generating *=alpha codes + betas=None, # if left unspecified, then let global parameter decide + edges=None, + isOptNLL=False): # if OptNLL or not (for StoreCInUnroll) + if not kernel["StoreCInUnroll"]: + if not self.do["PostLoop"]: return "" + kStr = "" + atomic = (kernel["GlobalSplitU"] > 1) and (kernel["_GlobalAccumulation"] != 'MultipleBuffer') + atomic = atomic or kernel["StreamK"] == 1 + useCodeMulAlpha = kernel["MIArchVgpr"] and applyAlpha and not (kernel["GlobalSplitU"] > 1) + + # write possibilities and labels + # if beta/edge combo not specified fall back to global param definition + if betas is None: + hasBeta = kernel["ProblemType"]["UseBeta"] and (kernel["GlobalSplitU"] == 1) + betas = [False, True] if hasBeta else [False] + if edges is None: + edges = [False, True] if self.do["EdgeWrite"] else [False] + writeLabels = {} + for beta in betas: + writeLabels[beta] = {} + for edge in edges: + writeLabels[beta]["EdgeCheck0"] = self.getNamedLabelUnique("GW_B%u_E%u_EdgeCheck0" % ( 1 if beta else 0, 1 if edge else 0) ) + writeLabels[beta]["EdgeCheck1"] = self.getNamedLabelUnique("GW_B%u_E%u_EdgeCheck1" % ( 1 if beta else 0, 1 if edge else 0) ) + writeLabels[beta][edge] = self.getNamedLabelUnique("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ) + if not beta: + betaLabel = self.getNamedLabelUnique("GW_Beta") + endLabel = self.getNamedLabelUnique("GW_End") + + # Layout + """ + if B1 goto label_B1 + if E1 goto label_B0_E1 + label_B0_E0: + writes + goto label_End + label_B0_E1: + writes + goto label_End + label_B1: + if E1 goto label_B1_E1 + label_B1_E0: + writes + goto label_End + label_B1_E1: + writes + goto label_End + label_End + """ + self.betaVgpr = None + + ######################################## + # Vgprs + if kernel["BufferStore"]: + numTmpVgpr = 2 + if len(kernel["PackedC0IndicesX"]) > 1: + numTmpVgpr += 1 + else: + numTmpVgpr = 2 + 3 # GLOBAL_OFFSET_C needs 3, plus 2 tmps? 
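# [editor's note: summary of the temp-VGPR count chosen above; illustrative only, not part of the patch]
#   BufferStore:      numTmpVgpr = 2   (+1 when more than one packed C0 index is present)
#   non-BufferStore:  numTmpVgpr = 2 + 3   (GLOBAL_OFFSET_C needs 3 extra)
#   the complex-alpha check just below may raise this further to DataType.numRegisters().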
+ if useCodeMulAlpha and kernel["ProblemType"]["DataType"].isComplex(): + # codeMulAlpha and complex case, use tmpVgpr for alpha calculation + numTmpVgpr = max(numTmpVgpr, kernel["ProblemType"]["DataType"].numRegisters()) + tmpVgpr = self.vgprPool.checkOutAligned(numTmpVgpr, 2, "store tmps") + + isHpaBF16 = kernel["ProblemType"]["DestDataType"].isBFloat16() and kernel["ProblemType"]["HighPrecisionAccumulate"] + isHpaF8 = kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8() # F8 is always HPA + + # need temp vgpr both for bf16 and f8 + if isHpaF8: + rcnt = 4 + if kernel["ProblemType"]["Fp32toFp8SWClip"]: + rcnt += 1 + if kernel["ProblemType"]["StochasticRounding"]: + rcnt += 2 + tmpCVTVgpr = self.vgprPool.checkOut(rcnt) + elif isHpaBF16: + tmpCVTVgpr = self.vgprPool.checkOut(4) + else: + tmpCVTVgpr = None + + ######################################## + # Sgprs + + # allocate tmps for the store header (before the batch implementations) + # tmpSgprRef = self.getTmpSgpr(4) + # tmpSgpr = tmpSgprRef.idx() - # Finish one write path, reset currPreLoopVmcntCase to Undefined - self.currPreLoopVmcntCase = PreLoopVmcntCase.Undefined + # branch B1 or B0 + betaLabel = self.getNamedLabelUnique("GW_Beta") + skPartialsLabel = self.getNamedLabelUnique("SK_Partials") + skFixupLabel = self.getNamedLabelUnique("SK_Fixup") + skStoreLabel = self.getNamedLabelUnique("SK_Store") + + if kernel["StreamK"] == 2 or kernel["StreamK"] == 3: + # StreamK store branches + tmpSgpr = self.sgprPool.checkOut(4, "globalWriteElements", preventOverflow=0) + # if we did not start the tile, store partials + # branch to beta == 0 store path + kStr += inst("s_cmp_eq_u32", sgpr("StreamKLocalStart"), 0, "does wg start tile?") + kStr += inst("s_cbranch_scc0 %s" % skPartialsLabel, "Branch if not start tile, store partials") + # if we started and finished the tile, regular store code + # branch to regular store code, skip fixup step + kStr += inst("s_cmp_eq_u32", sgpr("StreamKLocalEnd"), sgpr("ItersPerTile"), "does wg finish tile?") + kStr += inst("s_cbranch_scc1 %s" % skStoreLabel, "Branch if started and finished tile, go to regular store code") + + # if we started the tile but did not finish it, fix up step + # run fixup code before regular store code + sCtaIdx = self.sgprPool.checkOut(1, "CtaIdx", preventOverflow=0) # self.defineSgpr("CtaIdx", 1) + kStr += inst("s_add_u32", sgpr(sCtaIdx), sgpr("StreamKIdx"), 1, "input partial tile index") + + sFixupEnd = self.sgprPool.checkOut(1, "FixupEnd", preventOverflow=0) # self.defineSgpr("CtaEnd", 1) + kStr += self.sMagicDivAlg2(kernel, tmpSgpr, sgpr("StreamKIterEnd"), sgpr("MagicNumberItersPerTile"), sgpr("MagicShiftItersPerTile")) + kStr += inst("s_mul_i32", sgpr(tmpSgpr), sgpr(tmpSgpr), sgpr("ItersPerTile"), "start iteration of partial tile") + kStr += inst("s_sub_u32", sgpr(sFixupEnd), sgpr("StreamKIterEnd"), sgpr(tmpSgpr), "calc iterations completed by this WG") + + kStr += "%s:\n" % (skFixupLabel) + + # Check flag + kStr += inst("s_lshl_b32", sgpr(tmpSgpr), sgpr(sCtaIdx), log2(4), "flag offset based on CTA index") + kStr += inst("s_load_dword", sgpr(tmpSgpr+2), sgpr("AddressFlags", 2), sgpr(tmpSgpr), "glc", "get flag") + kStr += inst("s_waitcnt", "lgkmcnt(0)", "wait for flag load") + kStr += inst("s_cmp_eq_u32", sgpr(tmpSgpr+2), 1, "check if ready") + kStr += inst("s_cbranch_scc0 %s" % skFixupLabel, "if flag not set, wait and check again") + + self.sgprPool.checkIn(tmpSgpr) + + fixupEdge = [False] # Temporary hack to test no edge variant +
kStr += self.fixupStep(kernel, vectorWidths, elements, fixupEdge, tmpVgpr, tmpCVTVgpr, sCtaIdx, skStoreLabel) + + if kernel["StreamK"] == 3: + sIterCount = self.sgprPool.checkOut(1, "iterCount", preventOverflow=0) + kStr += inst("s_add_u32", sgpr(sIterCount), sgpr("SKItersPerWG"), 1, "Add extra iter") + kStr += inst("s_cmp_lt_u32", sgpr(sCtaIdx), sgpr("skExtraIters"), "Check if next WG had an extra iteration") + kStr += inst("s_cselect_b32", sgpr(sIterCount), sgpr(sIterCount), sgpr("SKItersPerWG"), "Select correct number of iterations for next WG") + kStr += inst("s_add_u32", sgpr(sFixupEnd), sgpr(sFixupEnd), sgpr(sIterCount), "next partial tile iteration") + self.sgprPool.checkIn(sIterCount) + kStr += inst("s_add_u32", sgpr(sCtaIdx), sgpr(sCtaIdx), 1, "next partial tile index") + if kernel["StreamK"] == 2: + kStr += inst("s_add_u32", sgpr(sFixupEnd), sgpr(sFixupEnd), sgpr("SKItersPerWG"), "next partial tile iteration") + kStr += inst("s_cmp_lt_u32", sgpr(sFixupEnd), sgpr("ItersPerTile"), "done loading partial tiles?") + kStr += inst("s_cbranch_scc1 %s" % skFixupLabel, "Branch to continue fixup loop") + kStr += "%s:\n" % (skStoreLabel) + + self.sgprPool.checkIn(sFixupEnd) + self.sgprPool.checkIn(sCtaIdx) + + if False in betas and True in betas: + kStr += self.checkIsBetaZero(kernel, betaLabel) + + for beta in betas: + # start B1 + if beta: + kStr += "%s:\n"%(betaLabel) + + ######################################## + # branch if Edge0 or Edge1 + if False in edges and True in edges: + kStr += self.checkIsEdge(kernel, "%s" % writeLabels[beta][True]) + + # by now we either jumped to E1 or stayed at E0 + for edge in edges: + # write label for batch case + kStr += "%s:%s"%(writeLabels[beta][edge], self.endLine) + kStr += self.globalWriteProcedure(kernel, vectorWidths, elements, applyAlpha, beta, edge, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel) + + if kernel["StreamK"] == 2 or kernel["StreamK"] == 3: + kStr += "%s:\n" % (skPartialsLabel) + fixupEdge = [False] # Temporary hack to test no edge variant + kStr += self.writePartials(kernel, vectorWidths, elements, fixupEdge, atomic, tmpVgpr, tmpCVTVgpr, isOptNLL, endLabel) + # End label kStr += "label_%s:%s"%(endLabel, self.endLine) self.vgprPool.checkIn(tmpVgpr) @@ -12010,7 +13577,7 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \ ############################################################################## def chooseGlobalWrite(self, useBuffer, bps, srcVgpr, rpv, \ - addr0, addr1, offset, extraFields, hi16=0, vb1Tmp=0): + addr0, addr1, offset, extraFields, hi16=0, vb1Tmp=0, soffset=None): """ create the store instruction for requested vector width and other parms rpv = regs per vector @@ -12027,6 +13594,10 @@ def chooseGlobalWrite(self, useBuffer, bps, srcVgpr, rpv, \ kStr += inst("s_mov_b32", tmpSgpr, offset, "large offset") offset = 0 + if soffset != None: + assert offset < 4096, "sgpr offset provided with large const offset" + tmpSgpr = soffset + if bps == 1 : #TODO: need to use _buffer_store_b8 macro.. but conflict with buffer_store_byte_d16 in chosen_store if hi16 == 0: @@ -12105,12 +13676,16 @@ def chooseGlobalWrite(self, useBuffer, bps, srcVgpr, rpv, \ return kStr ############################################################################## - def addStore(self, kernel, ss, addrCalc, sumIdx, tmpS01, edge): + def addStore(self, kernel, ss, addrCalc, sumIdx, tmpS01, edge, tc='D', wsOffset=None): """ Add stores for the element with addrCalc and sumIdx. 
tmpS01 is a single :temp sGPR """ kStr = "" + isWorkspace = tc == 'WS' + dataType = kernel["ProblemType"]["DestDataType"] + if isWorkspace: + dataType = kernel["ProblemType"]["ComputeDataType"] if self.do["GlobalWrite"]: # perform vector stores here, so no VI indexing. # if GWVW > Vw, might need to support loops to @@ -12123,49 +13698,57 @@ def addStore(self, kernel, ss, addrCalc, sumIdx, tmpS01, edge): bps = self.bpeCexternal * ss.cfg.gwvw rpv = self.bpeCexternal * ss.cfg.gwvw / self.bpr + if isWorkspace: + bps = self.bpeCinternal * ss.cfg.gwvw + rpv = self.bpeCinternal * ss.cfg.gwvw / self.bpr if kernel["BufferStore"]: addr0 = vgpr(addrCalc.addrDVgpr) - addr1 = sgpr("SrdD", 4) + addr1 = sgpr("Srd%s" % (tc), 4) else: addr0 = vgpr(addrCalc.addrDVgpr,2) addr1 = "" useBuffer = kernel["BufferStore"] - if ss.optSrdIncForRow and addrCalc.rowInc: + if ss.optSrdIncForRow and addrCalc.rowInc and not isWorkspace: kStr += addrCalc.incrementToNextRow(kernel, "D", ss, tmpS01) - if kernel["ProblemType"]["DestDataType"].isHalf() or kernel["ProblemType"]["DestDataType"].isBFloat16(): + + offset = addrCalc.globalOffset + if isWorkspace: + offset = 0 + + if dataType.isHalf() or dataType.isBFloat16(): if not kernel["ProblemType"]["HighPrecisionAccumulate"]: # (H,H,H,H,H,H), internal H if globalParameters["AsmCaps"][self.version]["HasWMMA"] and kernel["EnableMatrixInstructionStore"]: - kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, addr0, addr1, addrCalc.globalOffset, ntStr, hi16=0) + kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, addr0, addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) else: - kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx//2, rpv, addr0, addr1, addrCalc.globalOffset, ntStr, hi16=sumIdx%2) + kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx//2, rpv, addr0, addr1, offset, ntStr, hi16=sumIdx%2, vb1Tmp=0, soffset=wsOffset) else: # (B,B,B,B,S,S), internal S # (H,H,H,H,H,H), internal S # (H,H,H,H,S,S), internal S kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \ - addr0, addr1, addrCalc.globalOffset, ntStr, hi16=0) + addr0, addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) #TODO: need to test on emulator, always HPA for f8, data is already packed in dest - elif kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8(): + elif dataType.isFloat8() or dataType.isBFloat8(): kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \ - addr0, addr1, addrCalc.globalOffset, ntStr) + addr0, addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) - elif kernel["ProblemType"]["DestDataType"].isInt32() or kernel["ProblemType"]["DestDataType"].isSingle(): + elif dataType.isInt32() or dataType.isSingle(): kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \ - addr0, addr1, addrCalc.globalOffset, ntStr) - elif kernel["ProblemType"]["DestDataType"].isDouble() or kernel["ProblemType"]["DestDataType"].isSingleComplex(): + addr0, addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) + elif dataType.isDouble() or dataType.isSingleComplex(): if kernel["AtomicAddC"] and not edge: kStr += inst("buffer_atomic_add_f64", vgpr(sumIdx*2, 2), vgpr(addrCalc.addrDVgpr), sgpr("SrdD", 4), "0", "offen offset:{}".format(addrCalc.globalOffset), "AtomicAddC") else: kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx*2, rpv, \ - addr0, addr1, addrCalc.globalOffset, ntStr) - elif kernel["ProblemType"]["DestDataType"].isDoubleComplex(): - rps = kernel["ProblemType"]["DestDataType"].numRegisters() + addr0, 
addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) + elif dataType.isDoubleComplex(): + rps = dataType.numRegisters() kStr += self.chooseGlobalWrite(useBuffer, bps, sumIdx*rps, rpv, \ - addr0, addr1, addrCalc.globalOffset, ntStr) + addr0, addr1, offset, ntStr, hi16=0, vb1Tmp=0, soffset=wsOffset) return kStr @@ -12299,14 +13882,18 @@ def applyAlpha(self, kernel, gwvw, elementSumIdx, elementIdx, tmpS01): ############################################################################## # Global Read C Input ############################################################################## - def readCInput(self, kernel, ss, addrCalc, vc0, data, gwvw, addr, tmpS01): + def readCInput(self, kernel, ss, addrCalc, vc0, data, gwvw, addr, tmpS01, tc='C'): kStr = "" - bps = kernel["ProblemType"]["DestDataType"].numBytes() * gwvw useBuffer = kernel["BufferStore"] + isWorkspace = tc == 'WS' + dataType = kernel["ProblemType"]["DestDataType"] + if isWorkspace: + dataType = kernel["ProblemType"]["ComputeDataType"] + bps = dataType.numBytes() * gwvw if kernel["BufferStore"]: addr0 = vgpr(addr) - addr1 = sgpr("SrdC", 4) + addr1 = sgpr("Srd%s"%(tc), 4) else: addr0 = vgpr(addr,2) addr1 = "" @@ -12317,33 +13904,41 @@ def readCInput(self, kernel, ss, addrCalc, vc0, data, gwvw, addr, tmpS01): if kernel["NonTemporalC"]//2==1: extraStr += " " + getSlcBitName(kernel["MemoryModifierFormat"]) - if ss.optSrdIncForRow and addrCalc.rowInc: - kStr += addrCalc.incrementToNextRow(kernel, "C", ss, tmpS01) + if ss.optSrdIncForRow and addrCalc.rowInc and not isWorkspace: + kStr += addrCalc.incrementToNextRow(kernel, tc, ss, tmpS01) + + soffset = 0 + offset = addrCalc.globalOffset + comment = "load C for beta calc" + if isWorkspace: + soffset = tmpS01 + offset = 0 + comment = "load partials" - if kernel["ProblemType"]["DestDataType"].isHalf(): + if dataType.isHalf(): hi16 = 0 if self.HHH_WMMA else (vc0 % 2) kStr += self.chooseGlobalRead(useBuffer, bps, data, \ - addr0, addr1, soffset=0, offset=addrCalc.globalOffset, \ + addr0, addr1, soffset=soffset, offset=offset, \ extraFields=extraStr, dtlNoDestVgpr=False, hi16=hi16, \ - comment="load C for beta calc").toStr() + comment=comment).toStr() - elif kernel["ProblemType"]["DestDataType"].isFloat8() or \ - kernel["ProblemType"]["DestDataType"].isBFloat8(): + elif dataType.isFloat8() or \ + dataType.isBFloat8(): kStr += self.chooseGlobalRead(useBuffer, bps, data, \ - addr0, addr1, soffset=0, offset=addrCalc.globalOffset, \ + addr0, addr1, soffset=soffset, offset=offset, \ extraFields=extraStr, dtlNoDestVgpr=False, hi16=0, ubyteLoad=1, \ - comment="load C for beta calc").toStr() - - elif kernel["ProblemType"]["DestDataType"].isBFloat16() or \ - kernel["ProblemType"]["DestDataType"].isInt32() or \ - kernel["ProblemType"]["DestDataType"].isSingle() or \ - kernel["ProblemType"]["DestDataType"].isDouble() or \ - kernel["ProblemType"]["DestDataType"].isSingleComplex() or \ - kernel["ProblemType"]["DestDataType"].isDoubleComplex(): + comment=comment).toStr() + + elif dataType.isBFloat16() or \ + dataType.isInt32() or \ + dataType.isSingle() or \ + dataType.isDouble() or \ + dataType.isSingleComplex() or \ + dataType.isDoubleComplex(): kStr += self.chooseGlobalRead(useBuffer, bps, data, \ - addr0, addr1, soffset=0, offset=addrCalc.globalOffset, \ + addr0, addr1, soffset=soffset, offset=offset, \ extraFields=extraStr, dtlNoDestVgpr=False, \ - comment="load C for beta calc").toStr() + comment=comment).toStr() return kStr @@ -12370,8 +13965,8 @@ def globalWriteBatch(self, kernel, ss, batchIdx, 
applyAlpha, beta, edge, atomic, # comment tt1, tt0, vc1, vc0 # tt = thread tile, vc=vector component - commentStr = "Global Write%s%s Batch #%u (d1,d0,vc1,vc0) =\n " \ - % (" Beta" if beta else "", " Edge" if edge else "", batchIdx) + commentStr = "Global Write%s%s%s Batch #%u (d1,d0,vc1,vc0) =\n " \ + % (" Alpha" if applyAlpha else "", " Beta" if beta else "", " Edge" if edge else "", batchIdx) for elementIdx in range(0, len(batchElements)): element = batchElements[elementIdx] commentStr += "(%u,%u,%u,%u:vw%u%s)" % \ @@ -12624,8 +14219,8 @@ def globalWriteBatch(self, kernel, ss, batchIdx, applyAlpha, beta, edge, atomic, d0 = element[1] vc1 = element[2] vc0 = element[3] - labelString = "Global_Write%s%s_vc=%u,%u_d=%u,%u" \ - % (" Beta" if beta else "", " Edge" if edge else "", vc0, vc1, d0, d1 ) + labelString = "Global_Write%s%s%s_vc=%u,%u_d=%u,%u" \ + % (" Beta" if beta else "", " Edge" if edge else "", " Opt" if isOptNLL else "", vc0, vc1, d0, d1 ) label = self.getLabelNum(labelString) labelString += "EarlyExit" labelAfterAtomicLoop = self.getLabelNum(labelString) @@ -13492,7 +15087,12 @@ def closePrefetchGlobalRead2(self, kernel): ############################################################################## def persistentLoopendLongjump(self, kernel): kStr = "" - if kernel["PersistentKernel"]: + if kernel["StreamK"]: + endIter = "StreamKIterEnd" if kernel["StreamK"] < 3 else "TotalIters" + kStr += inst("s_cmp_ge_u32", sgpr("StreamKIter"), sgpr(endIter), "Check if done all StreamK iterations") + kStr += self.longBranchScc0(self.getLabelTarget("PersistentLoopStart"), negativeOnly=True) + + if kernel["PersistentKernel"]: # or kernel["StreamK"]: # Persistent may generate a SerialWorkGroupIter which is OOB, only loop back if we are in a valid WG: stmp = self.getTmpSgpr(1).idx() kStr += inst("s_mul_i32", sgpr(stmp), sgpr("NumWorkGroups0"), sgpr("NumWorkGroups1"), "Total WG-0x1") @@ -13508,7 +15108,7 @@ def persistentLoopendLongjump(self, kernel): ############################################################################## def functionEnd(self, kernel, addLabel=True): imod = Code.Module() - if kernel["PersistentKernel"]: + if kernel["PersistentKernel"] or kernel["StreamK"]: if kernel["StoreCInUnroll"]: # StoreCInUnroll case, reset StoreCAvail here for the next persistent loop (StoreCInUnroll disabled) imod.addCode(self.resetStoreCsyncObject(kernel)) @@ -14726,8 +16326,12 @@ def MapAcctoArchRegs(self, kernel, option, isOptNLL=False): acc2arch, _ = self.AccToArchMapper(kernel) complexMultiplier = 2 if kernel["ProblemType"]["DataType"].isComplex() else 1 + streamK = (kernel["StreamK"] == 2 or kernel["StreamK"] == 3) self.codeAccVgprRead = Code.Module("AccVgprRead") self.codeAccVgprRead.itemList = [None] * kernel["MIRegPerOut"] * complexMultiplier * len(acc2arch) + if streamK: + self.codeAccVgprWrite = Code.Module("AccVgprWrite") + self.codeAccVgprWrite.itemList = [None] * kernel["MIRegPerOut"] * complexMultiplier * len(acc2arch) # srcIdxList for MFMA+LSU+MIArchVgpr # This is used to store accumulated values in vgpr to LDS without v_mov instructions # Exception for complex case and StoreVectorWidth>1 case. 
Need to reorder vgpr with v_mov instructions @@ -14754,10 +16358,20 @@ def MapAcctoArchRegs(self, kernel, option, isOptNLL=False): self.codeAccVgprRead.itemList[destIdx] = Code.Inst("v_accvgpr_read_b32", vgpr("ValuC+__placeholder__"), accStr, "copy acc to vreg[%u]" % destIdx) + if streamK: + self.codeAccVgprWrite.itemList[destIdx] = Code.Inst("v_accvgpr_write_b32", + accStr, + vgpr("ValuC+__placeholder__"), + "copy vreg[%u] to acc" % destIdx) else: self.codeAccVgprRead.itemList[destIdx] = Code.Inst("v_mov_b32", vgpr("ValuC+__placeholder__"), vgpr("ValuC+%u"%srcIdx), "copy MI out reg to vreg[%u]" % destIdx) + if streamK: + self.codeAccVgprWrite.itemList[destIdx] = Code.Inst("v_mov_b32", + vgpr("ValuC+%u"%srcIdx), + vgpr("ValuC+__placeholder__"), "copy vreg[%u] to MI out reg" % destIdx) + if noMove: # keep srcIdx for MFMA+LSU+MIArchVgpr self.srcIdxList[destIdx] = srcIdx diff --git a/Tensile/KernelWriterConversion.py b/Tensile/KernelWriterConversion.py index 26f1b9a80..a612562da 100644 --- a/Tensile/KernelWriterConversion.py +++ b/Tensile/KernelWriterConversion.py @@ -34,6 +34,13 @@ def __init__(self, state): self.state["ProblemType"] = deepcopy(state["ProblemType"]) self.state["_GlobalAccumulation"] = state["_GlobalAccumulation"] + self.state["GlobalSplitU"] = state["GlobalSplitU"] if state["_GlobalAccumulation"] == 'MultipleBuffer' else 1 + self.state["GSUUnrollUnit"] = state["GSUUnrollUnit"] # number of unroll for large GSU + # mod part of GSU unroll. This will be fully unrolled. + mod = self.state["GlobalSplitU"] % self.state["GSUUnrollUnit"] + self.state["GSUmod"] = self.state["GSUUnrollUnit"] if mod == 0 else mod + self.state["VectorWidth"] = state["VectorWidth"] + self.state["Reduction"] = state["Reduction"] # derive parameter self.language = "HIP" @@ -75,7 +82,9 @@ def functionSignature(self): # alpha & beta kStr += " %s const alpha,%s" % (self.state["ProblemType"]["ComputeDataType"].toDevice(self.language), self.endLine) - kStr += " %s const beta,%s" % (self.state["ProblemType"]["ComputeDataType"].toDevice(self.language), self.endLine) + kStr += " %s const beta" % (self.state["ProblemType"]["ComputeDataType"].toDevice(self.language)) + + midEnd = ",%s"%self.endLine # strides firstStrideCD = 1 @@ -83,23 +92,25 @@ def functionSignature(self): firstStrideCD = 0 lastStrideC = self.state["ProblemType"]["NumIndicesC"] for i in range(firstStrideCD, lastStrideC): - kStr += " unsigned int const strideD%s,%s" % (self.indexChars[i], self.endLine) + kStr += "%s unsigned int const strideD%s" % (midEnd, self.indexChars[i]) for i in range(firstStrideCD, lastStrideC): - kStr += " unsigned int const strideW%s,%s" % (self.indexChars[i], self.endLine) + kStr += "%s unsigned int const strideW%s" % (midEnd, self.indexChars[i]) for i in range(firstStrideCD, lastStrideC): - kStr += " unsigned int const strideC%s,%s" % (self.indexChars[i], self.endLine) + kStr += "%s unsigned int const strideC%s" % (midEnd, self.indexChars[i]) # sizes for i in range(0, self.state["ProblemType"]["NumIndicesC"]): - kStr += " unsigned int const size%s,%s" % (self.indexChars[i], self.endLine) + kStr += "%s unsigned int const size%s" % (midEnd, self.indexChars[i]) - # gsu & SR + # gsu + kStr += "%s unsigned int const gsu" % midEnd + # SR if self.state["ProblemType"]["DestDataType"].is8bitFloat() \ and self.state["ProblemType"]["StochasticRounding"]: - kStr += " unsigned int const gsu,%s" % self.endLine - kStr += " const uint32_t RNDSeed)%s" % self.endLine - else: - kStr += " unsigned int const gsu)%s" % self.endLine + kStr += 
"%s const uint32_t RNDSeeds" % midEnd + + # put final end + kStr += ")%s" % self.endLine return kStr @@ -161,11 +172,21 @@ def kernelBody(self): kStr += " + (IDX%s)*strideC%s" % (indexChar, indexChar) kStr += " ))" + self.endLine + # define NUM_ELEMENT_LOAD and NUM_GSU for GlobalSplitUSeparatePost + mul_NEL = "" + div_NEL = "" + kStr += "#define NUM_ELEMENT_LOAD %d%s" % (self.state["VectorWidth"], self.endLine) + mul_NEL = "*NUM_ELEMENT_LOAD" + div_NEL = "/NUM_ELEMENT_LOAD" + # parallel reduction + kStr += "#define NUM_REDUCTION %d%s" % (self.state["Reduction"], self.endLine) + div_R = "/NUM_REDUCTION" + ######################################## # multi buffers GSU: Accumulate all GSU buffer indexChar = self.indexChars[0] kStr += " uint64_t id = %s(0);%s" % (self.getGlobalIdStr, self.endLine) - kStr += " if (id >= (size%s" % self.indexChars[0] + kStr += " if (id%s >= (size%s" % (mul_NEL+div_R, self.indexChars[0]) for i in range(1, problemType["NumIndicesC"]): kStr += "*size%s" % self.indexChars[i] kStr += "))%s" % self.endLine @@ -177,9 +198,15 @@ def kernelBody(self): kStr += ", id%d" % i kStr += ";%s" % self.endLine + # parallel reduction + if self.state["Reduction"] > 1: + kStr += " int idR = (int)(id %% NUM_REDUCTION);%s" % (self.endLine) + kStr += " id = id / NUM_REDUCTION;%s" % (self.endLine) for i in range(0, problemType["NumIndicesC"]): - kStr += " id%d = id %% size%s;%s" % (i, self.indexChars[i], self.endLine) - kStr += " id = id / size%s;%s" % (self.indexChars[i], self.endLine) + kStr += " id%d = (id %% (size%s%s))%s;%s" % (i, self.indexChars[i], div_NEL, mul_NEL,self.endLine) + kStr += " id = id / (size%s%s);%s" % (self.indexChars[i], div_NEL, self.endLine) + div_NEL = "" # for first iter only + mul_NEL = "" # for first iter only nonTileFreeIndices = [] @@ -241,35 +268,97 @@ def kernelBody(self): indexChar = self.indexChars[i] kStr += " + (size%s - 1) * strideW%s" % (indexChar, indexChar) kStr += ";" + self.endLine - kStr += " " + self.datatype + " accum = 0;%s" % self.endLine - kStr += "#pragma unroll%s" % self.endLine - kStr += " for (int i=0; i=2 (because VectorWidth > 1 here) + if (storeByte == 2): + storeTypeStr = "tensile_half" + else: + storeTypeStr = "float%u"% (storeByte // 4) + + # parallel reduction + if self.state["Reduction"] > 1: + kStr += " idxW += strideW * idR;%s" % (self.endLine) + kStr += " strideW *= NUM_REDUCTION;%s" % (self.endLine) + + # define accum variable(s) + for vi in range(self.state["VectorWidth"]): + kStr += " %s accum%u = 0;%s" % (self.datatype, vi, self.endLine) + # define result buffer + kStr += " %s result[NUM_ELEMENT_LOAD];%s"%(destTypeStr, self.endLine) + + idxStr = [""] if self.state["VectorWidth"] == 1 else [".x",".y",".z",".w"] # element access for wider load + if self.state["GlobalSplitU"] > self.state["GSUUnrollUnit"]: + # generate loop for large GSU + iterUnit = self.state["GSUUnrollUnit"] // self.state["Reduction"] + minusGSUmod = " - %d"%self.state["GSUmod"] if self.state["GSUmod"] == self.state["GSUUnrollUnit"] else "" + kStr += " uint32_t gsu_div = (gsu%s) / %u;%s" % (minusGSUmod, self.state["GSUUnrollUnit"], self.endLine) + kStr += " for (int i=0; i 1: + r = 1 + while r < self.state["Reduction"]: + for vi in range(self.state["VectorWidth"]): + kStr += " accum%d += __shfl_down(accum%d, %d, %d);%s" % (vi, vi, r, self.state["Reduction"], self.endLine) + r *= 2 + + # do alpha-beta only for idR==0 (representative index) + kStr += " if( idR != 0)%s" % (self.endLine) + kStr += " return;%s" % self.endLine + + #alpha + for vi in 
range(self.state["VectorWidth"]): + kStr += " accum%d *= (%s)alpha;%s" % (vi, self.datatype, self.endLine) + #Beta + kStr += " if( beta != (%s)0){%s" % (self.datatype, self.endLine) + for vi in range(self.state["VectorWidth"]): + # load C here + kStr += " accum%d += beta * (%s)C[idxC+%d];%s" % (vi, self.datatype, vi, self.endLine) kStr += " }%s" % self.endLine - kStr += " if( beta == (%s)0)%s" % (self.state["ProblemType"]["ComputeDataType"].toDevice(self.language), self.endLine) - kStr += " accum = ((" + self.datatype + ")alpha) * accum;%s" % (self.endLine) - kStr += " else%s" % self.endLine - kStr += " accum = (((" + self.datatype + ")alpha) * accum + ((" + self.datatype + ")beta) * ((" + self.datatype + ")C[idxC]));" + self.endLine - - typeStr = self.state["ProblemType"]["DestDataType"].toDevice(self.language) # Stochastic Rounding? need to use explicit_downcast if self.state["ProblemType"]["DestDataType"].is8bitFloat() \ and self.state["ProblemType"]["StochasticRounding"]: - # generate RND... For F8, computeDataType is always f32 - kStr += " uint32_t x = reinterpret_cast(accum);%s" % (self.endLine) - kStr += " uint32_t drop_bits = x & 0xFFFFu;%s" % (self.endLine) - kStr += " drop_bits ^= x >> 16;%s" % (self.endLine) - kStr += " drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);%s" % (self.endLine) - kStr += " drop_bits *= 0x7000149;%s" % (self.endLine) - kStr += " uint32_t rng = (drop_bits ^ 0x13371337 ^ (idxD * 229791) ^ RNDSeed);%s" % (self.endLine) - - # call explicit_downcast - cmpTypeStr = self.state["ProblemType"]["ComputeDataType"].toDevice(self.language) - kStr += " D[idxD] = explicit_downcast<%s, %s, true>(accum, rng);%s" % (typeStr, cmpTypeStr, self.endLine) + for vi in range(self.state["VectorWidth"]): + # generate RND... For F8, computeDataType is always f32 + kStr += " uint32_t x = reinterpret_cast(accum%d);%s" % (vi, self.endLine) + kStr += " uint32_t drop_bits = x & 0xFFFFu;%s" % (self.endLine) + kStr += " drop_bits ^= x >> 16;%s" % (self.endLine) + kStr += " drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);%s" % (self.endLine) + kStr += " drop_bits *= 0x7000149;%s" % (self.endLine) + kStr += " uint32_t rng = (drop_bits ^ 0x13371337 ^ (idxD * 229791) ^ RNDSeed);%s" % (self.endLine) + + # call explicit_downcast + cmpTypeStr = self.state["ProblemType"]["ComputeDataType"].toDevice(self.language) + kStr += " result[%d] = explicit_downcast<%s, %s, true>(accum%d, rng);%s" % (vi, destTypeStr, cmpTypeStr, vi, self.endLine) else: - kStr += " D[idxD] = (%s)accum;%s" % (typeStr, self.endLine) + #covert to output + for vi in range(self.state["VectorWidth"]): + kStr += " result[%d] = (%s)accum%d;%s" % (vi, destTypeStr, vi, self.endLine) + + kStr += " *((%s*)(D+idxD)) = *((%s*)(result));%s" % (storeTypeStr, storeTypeStr, self.endLine) ######################################## # end @@ -283,6 +372,9 @@ def kernelBody(self): kStr += "#undef GLOBAL_D%s" % (self.endLine) kStr += "#undef GLOBAL_W%s" % (self.endLine) kStr += "#undef GLOBAL_C%s" % (self.endLine) + kStr += "#undef NUM_ELEMENT_LOAD%s" % (self.endLine) + # parallel reduction + kStr += "#undef NUM_REDUCTION%s" % (self.endLine) return kStr @@ -297,7 +389,18 @@ def getKernelName(self): name += self.state["ProblemType"]["DestDataType"].toChar() name += "" if self.state["ProblemType"]["StridedBatched"] else "_GB" name += "_PostGSU" - + # add extra string for gsu (only for GSUUnrollUnit > 1) + if self.state["GSUUnrollUnit"] > 1: + # This part must match client code (in ContractionSolution.cpp) + gsuMod = 
self.state["GSUmod"] + modStr = "" + if self.state["GlobalSplitU"] > self.state["GSUUnrollUnit"]: + modStr += "_mod%u"%self.state["GSUUnrollUnit"] + name += "%u%s"%(gsuMod, modStr) + if self.state["VectorWidth"] > 1: + name += "_VW" + str(self.state["VectorWidth"]) + if self.state["Reduction"] > 1: + name += "_R" + str(self.state["Reduction"]) return name diff --git a/Tensile/KernelWriterSource.py b/Tensile/KernelWriterSource.py index 1c1289d17..965df5755 100644 --- a/Tensile/KernelWriterSource.py +++ b/Tensile/KernelWriterSource.py @@ -2500,6 +2500,12 @@ def globalReadDo(self, kernel, mode, tP, vregSetIdx=0): self.endLine) return kStr + ############################################################################## + # Global Read A/B completed + ############################################################################## + def doneGlobalABReads(self, kernel): + return "" + ############################################################################## # Local Write: Swap Offsets A/B ############################################################################## diff --git a/Tensile/KernelWriterStreamKInit.py b/Tensile/KernelWriterStreamKInit.py new file mode 100644 index 000000000..da1a75539 --- /dev/null +++ b/Tensile/KernelWriterStreamKInit.py @@ -0,0 +1,121 @@ +################################################################################ +# +# Copyright (C) 2023 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ + +from copy import deepcopy + +from .Common import globalParameters, CHeader +from .KernelWriterBase import KernelWriterBase + +class KernelWriterStreamKInit(KernelWriterBase): + + def __init__(self, state): + super().__init__() + + self.state["ProblemType"] = deepcopy(state["ProblemType"]) + self.state["_GlobalAccumulation"] = state["_GlobalAccumulation"] + + # derive parameter + self.language = "HIP" + self.kernelName = self.getKernelName() + + + def functionSignature(self): + kStr = "" + + # self.state name + kStr += self.endLine + kStr += "extern \"C\"" + self.endLine + kStr += "__global__ " + kStr += "void %s" % ( self.kernelName ) + kStr += "(" + self.endLine + + # pointers + kStr += " unsigned int * Flags," + self.endLine # Already offset to start of flags section in workspace + + kStr += " unsigned int const flagCount" + self.endLine + + kStr += " )%s" % (self.endLine) + + return kStr + + + ############################################################################## + # Kernel Body Stream-K Init + ############################################################################## + def kernelBodyStreamKInit(self): + kStr = "" + kStr += "{%s" % self.endLine + + ######################################## + # Stream-K initialize flags to 0 + kStr += " uint64_t id = %s(0);%s" % (self.getGlobalIdStr, self.endLine) + kStr += " if (id >= (flagCount))" + self.endLine + kStr += " return;%s" % self.endLine + kStr += self.endLine + + kStr += " Flags[id] = 0;" + self.endLine + + ######################################## + # end + kStr += "}%s" % self.endLine + + return kStr + + + def getKernelName(self): + # Output to workspace flags + name = "WSFlags" + # name += "_" + # name += self.state["ProblemType"]["DestDataType"].toChar() + return name + + + def getSourceFileString(self): + fileString = "" + + if not globalParameters["MergeFiles"]: + fileString += "\n" + fileString += "#include \"%s.h\"\n" % self.kernelName + fileString += "\n" + + fileString += self.functionSignature() + fileString += self.kernelBodyStreamKInit() + + return (0, fileString) + + def getHeaderFileString(self): + fileString = "" # CHeader + if not globalParameters["MergeFiles"]: + fileString += CHeader + fileString += "#pragma once\n\n" + fileString += "\n" + fileString += "#include \n\n" + fileString += "#include \n" + fileString += "#include \n" + fileString += "\n" + + fileString += self.functionSignature() + fileString += ";\n" + + return fileString diff --git a/Tensile/LibraryIO.py b/Tensile/LibraryIO.py index 4661e98d9..916c00916 100644 --- a/Tensile/LibraryIO.py +++ b/Tensile/LibraryIO.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -179,6 +179,9 @@ def parseLibraryLogicData(data, srcFile="?"): if "CUCount" not in data: data["CUCount"] = None + if "IsAPU" not in data: + data["IsAPU"] = None + if "Fp16AltImpl" not in data: data["Fp16AltImpl"] = False @@ -231,10 +234,14 @@ def parseLibraryLogicList(data, srcFile="?"): if type(data[2]) is dict: rv["ArchitectureName"] = data[2]["Architecture"] - rv["CUCount"] = data[2]["CUCount"] + if "CUCount" in data[2]: + rv["CUCount"] = data[2]["CUCount"] + if "IsAPU" in data[2]: + rv["IsAPU"] = data[2]["IsAPU"] else: rv["ArchitectureName"] = data[2] rv["CUCount"] = None + rv["IsAPU"] = None rv["ExactLogic"] = data[7] # data[8] previously contained range logic, which has been retired @@ -300,7 +307,10 @@ def createLibraryLogic(schedulePrefix, architectureName, deviceNames, logicTuple # architecture if type(architectureName) is dict: rv["ArchitectureName"] = architectureName["Architecture"] - rv["CUCount"] = architectureName["CUCount"] + if "CUCount" in architectureName: + rv["CUCount"] = architectureName["CUCount"] + if "IsAPU" in architectureName: + rv["IsAPU"] = architectureName["IsAPU"] else: rv["ArchitectureName"] = architectureName diff --git a/Tensile/SolutionLibrary.py b/Tensile/SolutionLibrary.py index c77ad385a..558727049 100644 --- a/Tensile/SolutionLibrary.py +++ b/Tensile/SolutionLibrary.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -276,17 +276,21 @@ def FromOriginalState(cls, def hardware(d, problemType, solutions, library, placeholderName): devicePart = d["ArchitectureName"] cuCount = d["CUCount"] + isAPU = d["IsAPU"] newLib = PredicateLibrary(tag="Hardware") if devicePart == "fallback": pred = Hardware.HardwarePredicate("TruePred") else: - pred = Hardware.HardwarePredicate.FromHardware(Common.gfxArch(devicePart), cuCount) + pred = Hardware.HardwarePredicate.FromHardware(Common.gfxArch(devicePart), cuCount, isAPU) newLib.rows.append({"predicate": pred, "library": library}) if lazyLibrary: - if cuCount: placeholderName += "_CU" + str(cuCount) + if cuCount: + placeholderName += "_CU" + str(cuCount) + if isAPU is not None: + placeholderName += "_APU" if isAPU == 1 else "_XPU" placeholderName += "_" + str(devicePart) return newLib, placeholderName diff --git a/Tensile/SolutionStructs.py b/Tensile/SolutionStructs.py index 4a2159a24..0239d23ff 100644 --- a/Tensile/SolutionStructs.py +++ b/Tensile/SolutionStructs.py @@ -32,6 +32,7 @@ from .DataType import DataType from .Utils import roundUpToNearestMultiple +from .KernelWriterStreamKInit import KernelWriterStreamKInit from .KernelWriterBetaOnly import KernelWriterBetaOnly from .KernelWriterConversion import KernelWriterConversion @@ -1781,15 +1782,28 @@ def getKernels(self): ######################################## # create Helper Kernels def initHelperKernelObjects(self): + self.initStreamKInitKernelObjects() self.initBetaOnlyKernelObjects() self.initConversionKernelObjects() ######################################## - # create BetaONly Kernels + # create StreamKInit Kernels + def 
initStreamKInitKernelObjects(self): + self.streamKInitKernelObjects = [] + if self["StreamK"] == 2 or self["StreamK"] == 3: + state = {} + state["ProblemType"] = deepcopy(self["ProblemType"]) + state["KernelLanguage"] = "Source" + state["_GlobalAccumulation"] = self["_GlobalAccumulation"] + self.streamKInitKernelObjects.append(KernelWriterStreamKInit(state)) + + + ######################################## + # create BetaOnly Kernels def initBetaOnlyKernelObjects(self): self.betaOnlyKernelObjects = [] - if self["GlobalSplitU"] > 1: + if self["GlobalSplitU"] > 1 or self["StreamK"] == 1: state = {} state["ProblemType"] = deepcopy(self["ProblemType"]) state["KernelLanguage"] = "Source" @@ -1801,23 +1815,63 @@ def initBetaOnlyKernelObjects(self): # create Conversion Kernels def initConversionKernelObjects(self): self.conversionKernelObjects = [] - if (self["GlobalSplitU"] > 1) and self["_GlobalAccumulation"]: - state = {} - state["ProblemType"] = deepcopy(self["ProblemType"]) - state["KernelLanguage"] = "Source" - state["_GlobalAccumulation"] = self["_GlobalAccumulation"] - self.conversionKernelObjects.append(KernelWriterConversion(state)) + gsu = self["GlobalSplitU"] + if (gsu > 1) and self["_GlobalAccumulation"]: + # wider load for GSU is single/double compute type only + supportedTypeForVWopt = self["ProblemType"]["ComputeDataType"].isSingle() or self["ProblemType"]["ComputeDataType"].isDouble() + vwMax = 1 + if (supportedTypeForVWopt): + vwMax = 2 + + # reduction for GSU is single/double compute type + gsu = power of 2 only + supportedTypeForReductionOpt = self["ProblemType"]["ComputeDataType"].isSingle() or self["ProblemType"]["ComputeDataType"].isDouble() + maxReduction = 1 + maxReductionConst = 4 # this must match the value in client code (ContractionSolution.cpp) + minGSUperReduction = 32 # Minimum GSU=128 for Reduction=4, GSU=64 for Reduction=2 + applicableReduction = max(1, gsu // minGSUperReduction) + if (supportedTypeForReductionOpt and ((gsu & (gsu - 1)) == 0) and self["_GlobalAccumulation"] == "MultipleBuffer"): + maxReduction = min(applicableReduction, maxReductionConst) # not exceeding reductionThreshold + + # loop unroll opt for postGSU + supportedTypeForUnrollOpt = self["ProblemType"]["ComputeDataType"].isSingle() or self["ProblemType"]["ComputeDataType"].isDouble() + + vw = 1 + while vw <= vwMax: + reduction = 1 + while reduction <= maxReduction: + # so far, reduction=2 does not perform well.
Skip 2 + if reduction == 2 and self["ProblemType"]["ComputeDataType"].isSingle(): + reduction *= 2 + continue + state = {} + state["ProblemType"] = deepcopy(self["ProblemType"]) + state["KernelLanguage"] = "Source" + state["_GlobalAccumulation"] = self["_GlobalAccumulation"] + state["GlobalSplitU"] = self["GlobalSplitU"] + state["VectorWidth"] = vw + state["Reduction"] = reduction + # number of unroll for large GSU (must match client code) + state["GSUUnrollUnit"] = 16 * state["Reduction"] if supportedTypeForUnrollOpt and self["_GlobalAccumulation"] == "MultipleBuffer" else 1 + self.conversionKernelObjects.append(KernelWriterConversion(state)) + reduction *= 2 + vw *= 2 ######################################## # get Helper Kernels def getHelperKernelObjects(self): - return self.betaOnlyKernelObjects + self.conversionKernelObjects + return self.streamKInitKernelObjects + self.betaOnlyKernelObjects + self.conversionKernelObjects ######################################## # get Helper Kernels - def getKernelBetaOlnyObjects(self): + def getKernelStreamKInitObjects(self): + return self.streamKInitKernelObjects + + + ######################################## + # get Helper Kernels + def getKernelBetaOnlyObjects(self): return self.betaOnlyKernelObjects @@ -2419,23 +2473,15 @@ def isDirectToVgprDoable(state, tc): reject(state, "DirectToVgpr%c does not supports TLU%c = True + numByte < 4"%(tc, tc)) return False - # MIWaveGroup check - # for A, MIWaveGroup[1] should be 1 - # for B, MIWaveGroup[0] should be 1 + # MIWaveGroup, MatrixInstBM,BN check + # for A, MIWaveGroup[1] and MatrixInstBN should be 1 + # for B, MIWaveGroup[0] and MatrixInstBM should be 1 # This is to limit the number of Vgpr - if tc == 'A' and not (state['MIWaveGroup'][1] == 1): - reject(state, "MIWaveGroup[1] should be 1 for DirectToVgprA. Current value is [%s]"%state['MIWaveGroup'][1]) + if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1): + reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN'])) return False - if tc == 'B' and not (state['MIWaveGroup'][0] == 1): - reject(state, "MIWaveGroup[0] should be 1 for DirectToVgprB. Current value is [%s]"%state['MIWaveGroup'][0]) - return False - - # Does not support MatrixInstBM, MatrixInstBN > 1 - if state['MatrixInstBM'] > 1: - reject(state, "MatrixInstBM should be 1 for DirectToVgpr. Current value is %s"%state['MatrixInstBM']) - return False - if state['MatrixInstBN'] > 1: - reject(state, "MatrixInstBN should be 1 for DirectToVgpr. Current value is %s"%state['MatrixInstBN']) + if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1): + reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. 
Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM'])) return False # Does not work with WaveSeparateGlobalRead @@ -2675,10 +2721,173 @@ def getDivisorName(state, tC): divisorName = "LVP{}".format(tC) return divisorName + ######################################## + # get number of elements in LDS for each tc (A or B) + @staticmethod + def getLdsNumElements(state, tc): + bpeAB = int(4*state["ProblemType"]["DataType"].numRegisters()) + if state["LdsBlockSizePerPad%s"%tc]: + padInterval = state["LdsBlockSizePerPad%s"%tc] // bpeAB + else: + if state["UnrollMajorLDS%s"%tc]: + padInterval = state["_DepthULds"] + else: + padInterval = state["MacroTile%s"%tc] + numElements = state["MacroTile%s"%tc] * state["_DepthULds"] + ldsNumElements = numElements + (numElements // padInterval) * (state["LdsPad%s"%tc]) + if state["DirectToVgpr%s"%tc]: + # DirectToVgpr does not use LDS. Set to 0. + ldsNumElements = 0 + return ldsNumElements + + ######################################## + # get number of aligned elements in LDS for each tc (A or B) + @staticmethod + def getLdsNumElementsAligned(state, tc): + ldsAlign = int(64 / state["ProblemType"]["DataType"].numRegisters()) + ldsNumElements = Solution.getLdsNumElements(state, tc) + ldsNumElementsAligned = roundUpToNearestMultiple(ldsNumElements, ldsAlign) + return ldsNumElementsAligned + + + ######################################## + # determine auto LdsPad and LdsBlockSizePerPad + @staticmethod + def ldsPaddingAuto(state, isa): + # LDS padding + # Resolve -1 before isDirectToLdsDoable check + numBytes = state["ProblemType"]["DataType"].numBytes() + optPad = state["LocalReadVectorWidth"] + readRegs = state["LocalReadVectorWidth"]*numBytes//4 + if (not globalParameters["AsmCaps"][isa]['HasWMMA']) and readRegs > 4: + reject(state, "LocalReadVectorWidth=%u results in attemping to read LDS larger than b128, reject") + + autoAdjusted = {"LdsPadA": False, "LdsPadB": False, "LdsBlockSizePerPadA": False, "LdsBlockSizePerPadB": False} + for tc in ('A','B'): + # set pad as readRegs to avoid unaligned read + idx01 = 0 if tc == 'A' else 1 + charMN = 'M' if tc == 'A' else 'N' + numBank = 32 + + # LdsBlockSizePerPad and LdsPad setting + autoCalcLBSPP = False + if state["LdsBlockSizePerPad%s"%tc] == -1: + state["LdsBlockSizePerPad%s"%tc] = 0 + autoCalcLBSPP = True + autoAdjusted["LdsBlockSizePerPad%s"%tc] = autoCalcLBSPP + autoCalcLP = False + if state["LdsPad%s"%tc] == -1: + autoCalcLP = True + if state["ProblemType"]["TLU%s"%tc] and (not state["UnrollMajorLDS%s"%tc]): + state["LdsPad%s"%tc] = 0 + else: + state["LdsPad%s"%tc] = state["VectorWidth"] + autoAdjusted["LdsPad%s"%tc] = autoCalcLBSPP + + if state["EnableMatrixInstruction"]: + # MI case + LRstrideLine = 0 # for LBSPP value check + LRstride = 0 + comment = "" + depthU = state["_DepthULds"] + vw = state["VectorWidth"] if tc=='A' else state["VectorWidthB"] + if not state["SourceSwap"]: + vw = 1 # TODO: support non-SourceSwap + vw + if state["UnrollMajorLDS%s"%tc]: + LRstrideLine = state["_DepthULds"] + comment = "DepthULds" + # if depthU is not power of 2, adjust ldsPad at each line (keep LRstride = 0) + if not (depthU > 0 and (depthU & (depthU - 1)) != 0): + LRstride = LRstrideLine * vw + else: + LRstrideLine = state["MacroTile%d"%idx01] + comment = "MT0" + if state["MIInputPerThread"] > 1: + # MIInputPerThread > 1 case, we still need padding to mitigate bank conflict even for non-UnrollMajorLDS case + LRstride = LRstrideLine * state["LocalReadVectorWidth"] + # auto calc for LBSPP + if 
autoCalcLBSPP and LRstride > 0: + state["LdsBlockSizePerPad%s"%tc] = max(int(2**(math.ceil(math.log(LRstride * numBytes, 2)))), 128) + # value check + if state["LdsBlockSizePerPad%s"%tc]: + if state["LdsBlockSizePerPad%s"%tc] < LRstrideLine: + reject(state, "reject: %s %u x bpe > LdsBlockSizePerPad%s %u" % (comment, LRstrideLine, tc, state["LdsBlockSizePerPad%s"%tc])) + # auto calc for LdsPad + if autoCalcLP: + miWidth = state["MatrixInst%s"%charMN] * state["MatrixInstB%s"%charMN] + if tc == 'B' and state["MatrixInstBM"] > 1: + # B and state["MatrixInstBM"] > 1 case, BN is not continuous. Use MatrixInstN only + miWidth = state["MatrixInstN"] + if (not state["UnrollMajorLDS%s"%tc]): + extra = 0 + if miWidth < numBank and miWidth * vw < 128 // numBytes: + extra = (miWidth * vw) % (128 // numBytes) + if extra: + divisor = 128//numBytes + mod = LRstride % divisor + state["LdsPad%s"%tc] = (divisor + extra - mod) % divisor + else: + optPadMN = optPad + # for readRegs = 1 or 4, we need to double pad for MI16x16xNx1 to avoid bank conflict. + if miWidth < numBank and (readRegs == 4 or readRegs == 1): + # UnrollMajorLds and miWidth < 32 case: the same M or N location (with K+1, K+2, ...) is accessed + # We need to offset (32//miWidth) times for the next M or N + # ex.) MI16x16, miWidth=16. To avoid bank conflicts, we need to double the padding for M1. + # [M0,K0]. [M0,K1] + # offset K0 and K1 -> [M1,K0]. [M1,K1] + # offset K0 and K1 -> [M2,K0]. [M2,K1] + # offset K0 and K1 -> [M3,K0]. [M3,K1] + optPadMN = optPad * (numBank//miWidth) + optPadMN = max(state["GlobalLoadVectorWidth%s"%tc], optPadMN) + # if depthU is not power of 2, adjust ldsPad to make depthU part multiple of 128/numBytes + if depthU > 0 and (depthU & (depthU - 1)) != 0: + divisor = 128//numBytes + ratio = roundupRatio(128//numBytes, depthU) + extraPadDU = (ratio * depthU) % (128//numBytes) + optPadMN_Plus = (extraPadDU + optPadMN) % (128//numBytes) + optPadMN_Minus = (extraPadDU - optPadMN) % (128//numBytes) + optPadMN = min(optPadMN_Plus, optPadMN_Minus) + state["LdsPad%s"%tc] = optPadMN + ## turn-off padding for directToLds + if state["DirectToLds%s"%tc]: + state["LdsPad%s"%tc] = 0 + + else: + # non MI case + if state["UnrollMajorLDS%s"%tc]: + reject(state, "didn't support UnrollMajorLDS in VALU mode yet") + if state["LdsBlockSizePerPad%s"%tc] != 0: + reject(state, "didn't support LdsBlockSizePerPad in VALU mode yet") + + assert(state["LdsPad%s"%tc] >= 0) + + # set LdsBlockSizePerPad = 0 if LdsPad is 0 + if state["LdsPad%s"%tc] == 0: + state["LdsBlockSizePerPad%s"%tc] = 0 + + # LDS size check for auto adjustment + # if LDS size is over the limit, change LdsPad to 0 + # check UnrollMajorLDS=False for A, B first, then True + for umlds in [False, True]: + for tc in ('A','B'): + # A side (aligned) + B side (not aligned) + ldsNumElementsAB = Solution.getLdsNumElementsAligned(state, 'A') + \ + Solution.getLdsNumElements(state, 'B') + if state["1LDSBuffer"] != 1: + # not 1LDSBuffer case, double the number of elements + ldsNumElementsAB *= 2 + if autoAdjusted["LdsPad%s"%tc] and (ldsNumElementsAB * numBytes) > globalParameters["MaxLDS"]: + # auto adjusted LdsPad and LDS overflow + if state["UnrollMajorLDS%s"%tc] == umlds: + # change LdsPad and LdsBlockSizePerPad to 0 + state["LdsPad%s"%tc] = 0 + state["LdsBlockSizePerPad%s"%tc] = 0 + ######################################## # assign all derived parameters @staticmethod def assignDerivedParameters(state): + isa = tuple(state["ISA"]) state["EnableF32XdlMathOp"] = False #ignore the F32 xDL MathOp by default.
#enable F32 xDL MathOp only when the input type is f32. @@ -2704,6 +2913,12 @@ def assignDerivedParameters(state): state["_GlobalAccumulation"] = None state["_WorkspaceSizePerElemC"] = 0 + if state["StreamK"] == 2 or state["StreamK"] == 3: + # StreamK Workspace size + computeBytes = state["ProblemType"]["ComputeDataType"].numBytes() + state["_GlobalAccumulation"] = 'PartialsBuffer' + state["_WorkspaceSizePerElemC"] = computeBytes + if state["GlobalSplitU"] > 1: computeName = state["ProblemType"]["ComputeDataType"].toName() computeBytes = state["ProblemType"]["ComputeDataType"].numBytes() @@ -2720,6 +2935,25 @@ def assignDerivedParameters(state): state["_GlobalAccumulation"] = 'MultipleBuffer' state["_WorkspaceSizePerElemC"] = computeBytes * state["GlobalSplitU"] + if state["StreamK"] != 0: + if state["MIWaveGroup"][0] * state["MIWaveGroup"][1] != 4: + reject(state, "Stream-K requires MIWaveGroup0*MIWaveGroup1=4") + if state["EnableMatrixInstruction"] and globalParameters["AsmCaps"][isa]["HasWMMA"]: + reject(state, "Stream-K untested with WMMA") + if state["GlobalSplitU"] > 1: + reject(state, "Cannot enable both Stream-K and GSU") + if state["PersistentKernel"]: + reject(state, "Cannot enable both Stream-K and PersistentKernel") + if not (2 in state["AssertSizeEqual"].keys() and state["AssertSizeEqual"][2] == 1): + reject(state, "Stream-K with batch requires further testing") + if state["StreamK"] == 1: + if not state["ProblemType"]["DataType"].isSingle(): + reject(state, "Atomic Stream-K currently only tested for SGEMM") + if not state["BufferStore"]: + reject(state, "Atomic Stream-K requires BufferStore") + if state["LocalSplitU"] > 1: + reject(state, "Atomic Stream-K does not work with LocalSplitU") + if state["VectorStore"] == -1: state["_VectorStore"] = 1 # default, may be changed if needed to generate a valid kernel @@ -2728,7 +2962,7 @@ def assignDerivedParameters(state): print2("in assignDerivedParameters, state['Valid'] = False") return - atomic = ((state["GlobalSplitU"] > 1) and (state["_GlobalAccumulation"] != 'MultipleBuffer')) or state["AtomicAddC"] + atomic = ((state["GlobalSplitU"] > 1) and (state["_GlobalAccumulation"] != 'MultipleBuffer')) or state["AtomicAddC"] or state["StreamK"] == 1 if atomic and globalParameters["DebugSkipAtomic"]: reject(state, "DEBUG: DebugSkipAtomic enabled, rejecting atomic kernel") if not atomic and globalParameters["DebugSkipNonAtomic"]: @@ -2747,8 +2981,6 @@ def assignDerivedParameters(state): state["1LDSBuffer"] = 1 print2("\nSet SIA=2, force PrefetchLocalRead=1, ExpandPointerSwap=1, 1LDSBuffer=1") - isa = tuple(state["ISA"]) - if "MemoryModifierFormat" not in state or state["MemoryModifierFormat"] not in validParameters["MemoryModifierFormat"]: if globalParameters["AsmCaps"][isa]["HasGLCModifier"]: state["MemoryModifierFormat"] = "GLC" @@ -3252,6 +3484,10 @@ def assignDerivedParameters(state): else: state["LocalReadVectorWidth"] = state["VectorWidth"] + # reject if FractionalLoad and depthU is not power of 2 (does not work) + if state["FractionalLoad"] and (depthU & (depthU - 1)) != 0: + reject(state, "FractionalLoad requires DepthU = power of 2") + ######################################## # Search DepthU # Inputs: @@ -3526,43 +3762,21 @@ def assignDerivedParameters(state): pvar(state, "LVCB"), pvar(state, "LVPB")) # lds buffer size for A, B - if state["KernelLanguage"] == "Source" and \ - state["LdsPadA"] != state["LdsPadB"]: - reject(state, "Source KernelLanguage only supports LdsPadA == LdsPadB") - return + if state["KernelLanguage"] ==
"Source": + # source kernel + # use 0 for auto adjust (-1) + for tc in ('A','B'): + if state["LdsBlockSizePerPad%s"%tc] == -1: + state["LdsBlockSizePerPad%s"%tc] = 0 + if state["LdsPad%s"%tc] == -1: + state["LdsPad%s"%tc] = 0 + if state["LdsPadA"] != state["LdsPadB"]: + reject(state, "Source KernelLanguage only supports LdsPadA == LdsPadB") + return ######################################## # LDS ######################################## - if state["LdsBlockSizePerPad"] == -1: - if state["MatrixInstruction"] and (state["UnrollMajorLDSA"] or state["UnrollMajorLDSB"]): - state["LdsBlockSizePerPad"] = 128 - if state["_DepthULds"]*state["ProblemType"]["DataType"].numBytes() > state["LdsBlockSizePerPad"]: - state["LdsBlockSizePerPad"] = int(2**(math.ceil(math.log(state["_DepthULds"]*state["ProblemType"]["DataType"].numBytes(), 2)))) - else: - state["LdsBlockSizePerPad"] = 0 - - state["LdsBlockSizePerPadA"] = state["LdsBlockSizePerPad"] if state["UnrollMajorLDSA"] else 0 - state["LdsBlockSizePerPadB"] = state["LdsBlockSizePerPad"] if state["UnrollMajorLDSB"] else 0 - - if state["EnableMatrixInstruction"]: - if state["LdsBlockSizePerPadA"]: - if not state["UnrollMajorLDSA"]: - reject(state, "didn't support LdsBlockSizePerPadA on tile major LDS yet") - if state["LdsBlockSizePerPadA"] < state["_DepthULds"]*state["ProblemType"]["DataType"].numBytes(): - reject(state, "reject: DepthULds %u x bpe > LdsBlockSizePerPadA %u" % (state["_DepthULds"], state["LdsBlockSizePerPad"])) - - if state["LdsBlockSizePerPadB"]: - if not state["UnrollMajorLDSB"]: - reject(state, "didn't support LdsBlockSizePerPadB on tile major LDS yet") - if state["LdsBlockSizePerPadB"] < state["_DepthULds"]*state["ProblemType"]["DataType"].numBytes(): - reject(state, "reject: DepthULds %u x bpe > LdsBlockSizePerPadB %u" % (state["_DepthULds"], state["LdsBlockSizePerPad"])) - else: - if state["UnrollMajorLDSA"] or state["UnrollMajorLDSB"]: - reject(state, "didn't support UnrollMajorLDS in VALU mode yet") - if state["LdsBlockSizePerPadA"] != 0 or state["LdsBlockSizePerPadB"] != 0: - reject(state, "didn't support LdsBlockSizePerPad in VALU mode yet") - # allow LocalReadVectorWidthB > 1 for TLUB + MatrixInstruction (this is applicable for B only) # some more limitations necessary to make this logic work # - MatrixInstruction @@ -3616,6 +3830,10 @@ def assignDerivedParameters(state): # disable DTL state["DirectToLdsB"] = False + # LDS padding + # Resolve -1 before isDirectToLdsDoable check + Solution.ldsPaddingAuto(state, isa) + # Determine if we can load directly-to-LDS. # Transpose requires a trip through registers to perform the transpose so can't use DirectToLdsA # LDS loads always write 4 bytes apart so can use only 4-byte operations @@ -3703,73 +3921,10 @@ def assignDerivedParameters(state): reject(state, "PAPM + PGR=2 does not work if AssertSizeGreaterThan for K is not greater than DepthU * 2 - 1") return - # set pad as readRegs to avoid unaligned read - optPad = state["LocalReadVectorWidth"] - readRegs = state["LocalReadVectorWidth"]*state["ProblemType"]["DataType"].numBytes()//4 - if (not globalParameters["AsmCaps"][isa]['HasWMMA']) and readRegs > 4: - reject(state, "LocalReadVectorWidth=%u results in attemping to read LDS larger than b128, reject") - - if state["EnableMatrixInstruction"]: - # for readRegs = 1 or 4, we need to double pad for MI16x16xNx1 to avoid bank conflict. 
- if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16 and \ - (readRegs == 4 or readRegs == 1): - optPad *= 2 - if state["LdsPadA"] == -1: - if state["ProblemType"]["TLUA"] and (not state["UnrollMajorLDSA"]): - state["LdsPadA"] = 0 - else: - if state["EnableMatrixInstruction"] and state["UnrollMajorLDSA"]: - state["LdsPadA"] = max(state["GlobalReadVectorWidth"],optPad) - else: - state["LdsPadA"] = state["VectorWidth"] - ## turn-off padding for directToLds - if state["EnableMatrixInstruction"] and state["UnrollMajorLDSA"] and state["DirectToLdsA"]: - state["LdsPadA"] = 0 - assert(state["LdsPadA"] >= 0) - if state["LdsPadB"] == -1: - if state["ProblemType"]["TLUB"] and (not state["UnrollMajorLDSB"]): - state["LdsPadB"] = 0 - else: - if state["EnableMatrixInstruction"] and state["UnrollMajorLDSB"]: - state["LdsPadB"] = max(state["GlobalReadVectorWidth"],optPad) - else: - state["LdsPadB"] = state["VectorWidth"] - if state["EnableMatrixInstruction"] and state["UnrollMajorLDSB"] and state["DirectToLdsB"]: - state["LdsPadB"] = 0 - assert(state["LdsPadB"] >= 0) - - if (state["UnrollMajorLDSA"] or state["UnrollMajorLDSB"]) and (not state["EnableMatrixInstruction"]): - reject(state, "UnrollMajorLDS Supports only in EnableMatrixInstruction=1") - - ldsAlign = int(64 / state["ProblemType"]["DataType"].numRegisters()) - - if state["UnrollMajorLDSA"]: - ldsNumElementsA = (state["_DepthULds"] + state["LdsPadA"]) * state["MacroTileA"] - padInterval = state["LdsBlockSizePerPadA"] // bpeAB - if padInterval != 0: - ldsNumElementsA = int((state["_DepthULds"] * state["MacroTileA"]) / padInterval * (padInterval + state["LdsPadA"])) - ldsNumElementsAlignedA = roundUpToNearestMultiple(ldsNumElementsA, ldsAlign) - else: - ldsNumElementsA = state["_DepthULds"] * (state["MacroTileA"] + state["LdsPadA"]) - ldsNumElementsAlignedA = roundUpToNearestMultiple(ldsNumElementsA, ldsAlign) - if state["DirectToVgprA"]: - # DirectToVgpr does not use LDS. Set to 0. - ldsNumElementsA = 0 - ldsNumElementsAlignedA = 0 - - if state["UnrollMajorLDSB"]: - ldsNumElementsB = (state["_DepthULds"] + state["LdsPadB"]) * state["MacroTileB"] - padInterval = state["LdsBlockSizePerPadB"] // bpeAB - if padInterval != 0: - ldsNumElementsB = int((state["_DepthULds"] * state["MacroTileB"]) / padInterval * (padInterval + state["LdsPadB"])) - ldsNumElementsAlignedB = roundUpToNearestMultiple(ldsNumElementsB, ldsAlign) - else: - ldsNumElementsB = state["_DepthULds"] * (state["MacroTileB"] + state["LdsPadB"]) - ldsNumElementsAlignedB = roundUpToNearestMultiple(ldsNumElementsB, ldsAlign) - if state["DirectToVgprB"]: - # DirectToVgpr does not use LDS. Set to 0. - ldsNumElementsB = 0 - ldsNumElementsAlignedB = 0 + #ldsNumElementsA = Solution.getLdsNumElements(state, 'A') # not used + ldsNumElementsB = Solution.getLdsNumElements(state, 'B') + ldsNumElementsAlignedA = Solution.getLdsNumElementsAligned(state, 'A') + ldsNumElementsAlignedB = Solution.getLdsNumElementsAligned(state, 'B') # todo, can the alignment be a power of 2? 
state["LdsOffsetA"] = 0 @@ -4384,6 +4539,13 @@ def assignDerivedParameters(state): (state["ThreadTile0"] == 4 and state["ThreadTile1"] == 8)): reject(state, "UnrollLoopEfficiencyEnable does not support ThreadTile0,1 = [%u,%u]"%(state["ThreadTile0"], state["ThreadTile1"])) + # reject check for ClusterLocalRead + if state["ClusterLocalRead"]: + # Requires VgprForLocalReadPacking + if not state["VgprForLocalReadPacking"]: + reject(state, "ClusterLocalRead works with VgprForLocalReadPacking") + return + # reject check for VgprForLocalReadPacking if state["VgprForLocalReadPacking"]: # MatrixInstruction only diff --git a/Tensile/SolutionWriter.py b/Tensile/SolutionWriter.py index d88cd1d32..6780c52cd 100644 --- a/Tensile/SolutionWriter.py +++ b/Tensile/SolutionWriter.py @@ -191,7 +191,7 @@ def getProblemSourceString(self, problemType, solution, kernelsWithBuildErrs): s += "%s}\n" % (t) if gsu > 1: - for ko in solution.getKernelBetaOlnyObjects(): + for ko in solution.getKernelBetaOnlyObjects(): kernelName = ko.getKernelName(ko) s += "%scl_kernel kernel_%s;\n" % (t, kernelName) s += "%s tensileGetCompiledOpenCLKernel(\n" % (t) @@ -469,7 +469,7 @@ def getProblemSourceString(self, problemType, solution, kernelsWithBuildErrs): ######################################## if gsu > 1 and kernel["_GlobalAccumulation"] != 'MultipleBuffer': kernelNamesBetaOnly = [] - for ko in solution.getKernelBetaOlnyObjects(): + for ko in solution.getKernelBetaOnlyObjects(): kernelName = ko.getKernelName() kernelNamesBetaOnly.append(kernelName) s += "%s// enqueue Beta-Only kernel\n" % (t) diff --git a/Tensile/Source/TensileTypes.h b/Tensile/Source/TensileTypes.h index cb26bf17f..a1d8c29db 100644 --- a/Tensile/Source/TensileTypes.h +++ b/Tensile/Source/TensileTypes.h @@ -1,6 +1,6 @@ /******************************************************************************* * - * Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -174,6 +174,13 @@ constexpr __host__ __device__ bool operator==(tensile_complex const& a, return (a.x == b.x) && (a.y == b.y); } +template +constexpr __host__ __device__ bool operator!=(tensile_complex const& a, + tensile_complex const& b) +{ + return (a.x != b.x) || (a.y != b.y); +} + using tensile_float_complex = tensile_complex; using tensile_double_complex = tensile_complex; diff --git a/Tensile/Source/client/main.cpp b/Tensile/Source/client/main.cpp index cd26f3820..dd1f0641a 100644 --- a/Tensile/Source/client/main.cpp +++ b/Tensile/Source/client/main.cpp @@ -451,9 +451,9 @@ namespace Tensile if(solutionIterator->runCurrentSolution()) { - maxWorkspaceSize - = std::max(maxWorkspaceSize, - solution->requiredWorkspaceSize(problems[problemIdx])); + maxWorkspaceSize = std::max( + maxWorkspaceSize, + solution->requiredWorkspaceSize(problems[problemIdx], *hardware)); } listeners.postSolution(); diff --git a/Tensile/Source/client/source/HardwareMonitor.cpp b/Tensile/Source/client/source/HardwareMonitor.cpp index cb2b409b2..0f142dec8 100644 --- a/Tensile/Source/client/source/HardwareMonitor.cpp +++ b/Tensile/Source/client/source/HardwareMonitor.cpp @@ -96,11 +96,13 @@ namespace Tensile hipDeviceIndex)); } #endif - + // PCIID format changes from 32bit to 64bit, Below is the new PCI format ID from ROCm 4.0. 
+ // Note:FUNCTION[0:2]bits is only used by aquanjaram TPX mode. This is not supported in HIP API yet, will need modification in future. + // BDFID = ((DOMAIN & 0xffffffff) << 32) | ((BUS & 0xff) << 8) | ((DEVICE & 0x1f) <<3 ) | (FUNCTION & 0x7) uint64_t hipPCIID = 0; - hipPCIID |= props.pciDeviceID & 0xFF; - hipPCIID |= ((props.pciBusID & 0xFF) << 8); - hipPCIID |= (props.pciDomainID) << 16; + hipPCIID |= (((uint64_t)props.pciDomainID & 0xffffffff) << 32); + hipPCIID |= ((props.pciBusID & 0xff) << 8); + hipPCIID |= ((props.pciDeviceID & 0x1f) << 3); uint32_t smiCount = 0; diff --git a/Tensile/Source/client/source/SolutionIterator.cpp b/Tensile/Source/client/source/SolutionIterator.cpp index 5c5d83fb8..9a2147e24 100644 --- a/Tensile/Source/client/source/SolutionIterator.cpp +++ b/Tensile/Source/client/source/SolutionIterator.cpp @@ -88,6 +88,7 @@ namespace Tensile // Test if the persistent kernel is eligible for the current hw and solution m_problem.checkPersistentKernelEligibility(solution, *m_hardware); + m_problem.checkRequiredWorkspaceSize(solution, *m_hardware); if(!(*solution.problemPredicate)(m_problem)) { m_reporter->report(ResultKey::Validation, "DID_NOT_SATISFY_ASSERTS"); diff --git a/Tensile/Source/lib/CMakeLists.txt b/Tensile/Source/lib/CMakeLists.txt index 64af4a88e..f8cc527c8 100644 --- a/Tensile/Source/lib/CMakeLists.txt +++ b/Tensile/Source/lib/CMakeLists.txt @@ -143,6 +143,7 @@ if(TENSILE_USE_HIP) if(WIN32) target_compile_options( TensileHost PUBLIC -Wno-deprecated-declarations -Wno-ignored-attributes -Wdll-attribute-on-redeclaration -fdelayed-template-parsing ) + target_link_libraries( TensileHost PUBLIC Shlwapi ) endif() endif() diff --git a/Tensile/Source/lib/include/Tensile/AMDGPU.hpp b/Tensile/Source/lib/include/Tensile/AMDGPU.hpp index e87a4d005..4c425e593 100644 --- a/Tensile/Source/lib/include/Tensile/AMDGPU.hpp +++ b/Tensile/Source/lib/include/Tensile/AMDGPU.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -190,8 +190,11 @@ namespace Tensile } AMDGPU(); - AMDGPU(Processor p, int computeUnitCount, std::string const& deviceName); - AMDGPU(std::string const& archName, int computeUnitCount, std::string const& deviceName); + AMDGPU(Processor p, int computeUnitCount, int isAPU, std::string const& deviceName); + AMDGPU(std::string const& archName, + int computeUnitCount, + int isAPU, + std::string const& deviceName); ~AMDGPU(); @@ -199,6 +202,7 @@ namespace Tensile int wavefrontSize = 64; int simdPerCu = 4; int computeUnitCount = 0; + int isAPU = 0; std::string deviceName; virtual bool runsKernelTargeting(Processor p) const; diff --git a/Tensile/Source/lib/include/Tensile/AMDGPUPredicates.hpp b/Tensile/Source/lib/include/Tensile/AMDGPUPredicates.hpp index c8ebf2db5..c065af7ed 100644 --- a/Tensile/Source/lib/include/Tensile/AMDGPUPredicates.hpp +++ b/Tensile/Source/lib/include/Tensile/AMDGPUPredicates.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -96,6 +96,32 @@ namespace Tensile } }; + struct IsAPU : public Predicate_CRTP + { + enum + { + HasIndex = false, + HasValue = true + }; + int value; + + IsAPU() = default; + IsAPU(int val) + : value(val) + { + } + + static std::string Type() + { + return "IsAPU"; + } + + virtual bool operator()(AMDGPU const& gpu) const + { + return gpu.isAPU == value; + } + }; + struct RunsKernelTargeting : public Predicate_CRTP { enum diff --git a/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp b/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp index 45b0d6002..01b2a049f 100644 --- a/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp +++ b/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp @@ -810,14 +810,24 @@ namespace Tensile return m_workspaceSize; } + size_t getNumTiles(SizeMapping const& sizeMapping) const; + size_t getItersPerTile(SizeMapping const& sizeMapping) const; + void checkPersistentKernelEligibility(ContractionSolution const& solution, Hardware const& hardware); + void checkRequiredWorkspaceSize(ContractionSolution const& solution, + Hardware const& hardware); bool getPersistentKernelEligibility() const { return m_eligibleForPK; } + size_t getRequiredWorkspaceSize() const + { + return m_requiredWorkspaceSize; + } + private: TensorDescriptor m_a; TensorDescriptor m_b; @@ -845,6 +855,7 @@ namespace Tensile bool m_fp16AltImpl = false; bool m_fp16AltImplRound = false; bool m_stochasticRounding = false; + size_t m_requiredWorkspaceSize = 0; DataType m_f32XdlMathOp = DataType::Float; ArithmeticUnit m_arithmeticUnit = ArithmeticUnit::Any; KernelLanguage m_kernelLanguage = KernelLanguage::Any; diff --git a/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp b/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp index 50f034107..140341a8f 100644 --- a/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp +++ b/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp @@ -1183,7 +1183,7 @@ namespace Tensile virtual bool operator()(ContractionProblem const& problem) const override { - return problem.d().totalLogicalElements() * value <= problem.workspaceSize(); + return problem.getRequiredWorkspaceSize() <= problem.workspaceSize(); } virtual bool debugEval(ContractionProblem const& problem, @@ -1191,7 +1191,7 @@ namespace Tensile { bool rv = (*this)(problem); - stream << *this << ": (" << problem.d().totalLogicalElements() << " * " << value + stream << *this << ": (" << problem.getRequiredWorkspaceSize() << " <= " << problem.workspaceSize() << ") == " << rv; return rv; diff --git a/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp b/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp index 0768f071c..c04a8a27a 100644 --- a/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp +++ b/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -65,6 +65,32 @@ namespace Tensile size_t depthUorMT1; }; + struct SizeMapping + { + dim3 workGroupSize; + dim3 threadTile; + dim3 macroTile; + + size_t staggerU = 0; + size_t depthU = 0; + size_t globalSplitU = 0; + size_t staggerStrideShift = 0; + int workGroupMapping = 0; + + size_t packBatchDims = 0; + int packSummationDims = 0; + int magicDivAlg = 1; + int streamK = 0; + int persistentKernel = 0; + bool persistentKernelAlongBatch = false; + + bool sourceKernel = false; + int globalAccumulation = 0; + size_t workspaceSizePerElemC = 0; + }; + + std::ostream& operator<<(std::ostream& stream, const SizeMapping& sizeMapping); + /** * Represents a single kernel or set of kernels that can perform a single * tensor contraction. @@ -91,6 +117,7 @@ namespace Tensile { return kernelName; } + virtual std::string name() const { return kernelName; @@ -100,6 +127,9 @@ namespace Tensile return kernelName; } + bool getMatrixInstructionFromKernelName(vector4& miInst) const; + bool getGSUAlgorithmFromKernelName(std::string& gsuAlg) const; + bool isSourceKernel() const; //! Estimates based on problem size, solution tile, and machine hardware @@ -169,9 +199,10 @@ namespace Tensile }; /** - * Calculate required workspace size. - */ - size_t requiredWorkspaceSize(Problem const& problem) const; + * Calculate required workspace size. + */ + size_t requiredWorkspaceSize(Problem const& problem, Hardware const& hardware) const; + size_t partialTileSize(size_t skGrid) const; static float computeGranularity(float x); Granularities computeGranularities( @@ -222,6 +253,16 @@ namespace Tensile TypedInputs const& inputs, Hardware const& hardware) const; + template + KernelInvocation generateStreamKInitCall(Problem const& problem, + TypedInputs const& inputs, + Hardware const& hardware) const; + + template + std::string streamKInitKernelName(Problem const& problem, + TypedInputs const& inputs, + Hardware const& hardware) const; + template KernelInvocation generateBetaOnlyCall(Problem const& problem, TypedInputs const& inputs, @@ -240,35 +281,16 @@ namespace Tensile template std::string outputConversionKernelName(Problem const& problem, TypedInputs const& inputs, + int gsu, + int vw, + int reduction, + int gsuUnrollUnit, Hardware const& hardware) const; bool canSolve(Problem const& problem, Hardware const& hardware) const; bool matchesProblemType(Problem const& problem, Hardware const& hardware) const; - struct SizeMapping - { - dim3 workGroupSize; - dim3 threadTile; - dim3 macroTile; - - size_t staggerU = 0; - size_t depthU = 0; - size_t globalSplitU = 0; - size_t staggerStrideShift = 0; - int workGroupMapping = 0; - - size_t packBatchDims = 0; - int packSummationDims = 0; - int magicDivAlg = 1; - int persistentKernel = 0; - bool persistentKernelAlongBatch = false; - - bool sourceKernel = false; - int globalAccumulation = 0; - size_t workspaceSizePerElemC = 0; - }; - struct ProblemType { std::string operationIdentifier; diff --git a/Tensile/Source/lib/include/Tensile/ContractionSolution_fwd.hpp b/Tensile/Source/lib/include/Tensile/ContractionSolution_fwd.hpp index bec710c0f..0b6d672fd 100644 --- a/Tensile/Source/lib/include/Tensile/ContractionSolution_fwd.hpp +++ b/Tensile/Source/lib/include/Tensile/ContractionSolution_fwd.hpp @@ -29,4 +29,5 @@ namespace Tensile { class ContractionSolution; + struct SizeMapping; } diff --git 
a/Tensile/Source/lib/include/Tensile/Debug.hpp b/Tensile/Source/lib/include/Tensile/Debug.hpp index 28b8f5825..364b5c429 100644 --- a/Tensile/Source/lib/include/Tensile/Debug.hpp +++ b/Tensile/Source/lib/include/Tensile/Debug.hpp @@ -66,6 +66,8 @@ namespace Tensile bool printWinningKernelName() const; + bool printKernelCommonParams() const; + bool printSolutionSelectionTime() const; bool printLibraryLogicIndex() const; diff --git a/Tensile/Source/lib/include/Tensile/PlaceholderLibrary.hpp b/Tensile/Source/lib/include/Tensile/PlaceholderLibrary.hpp index 2e4c00e33..10898ec2d 100644 --- a/Tensile/Source/lib/include/Tensile/PlaceholderLibrary.hpp +++ b/Tensile/Source/lib/include/Tensile/PlaceholderLibrary.hpp @@ -67,45 +67,45 @@ namespace Tensile switch(condition) { case LazyLoadingInit::All: - return "TensileLibrary_.*"; + return "TensileLibrary_*"; case LazyLoadingInit::gfx803: - return "TensileLibrary_*_gfx803.*"; + return "TensileLibrary_*_gfx803"; case LazyLoadingInit::gfx900: - return "TensileLibrary_*_gfx900.*"; + return "TensileLibrary_*_gfx900"; case LazyLoadingInit::gfx906: - return "TensileLibrary_*_gfx906.*"; + return "TensileLibrary_*_gfx906"; case LazyLoadingInit::gfx908: - return "TensileLibrary_*_gfx908.*"; + return "TensileLibrary_*_gfx908"; case LazyLoadingInit::gfx90a: - return "TensileLibrary_*_gfx90a.*"; + return "TensileLibrary_*_gfx90a"; case LazyLoadingInit::gfx940: - return "TensileLibrary_*_gfx940.*"; + return "TensileLibrary_*_gfx940"; case LazyLoadingInit::gfx941: - return "TensileLibrary_*_gfx941.*"; + return "TensileLibrary_*_gfx941"; case LazyLoadingInit::gfx942: - return "TensileLibrary_*_gfx942.*"; + return "TensileLibrary_*_gfx942"; case LazyLoadingInit::gfx1010: - return "TensileLibrary_*_gfx1010.*"; + return "TensileLibrary_*_gfx1010"; case LazyLoadingInit::gfx1011: - return "TensileLibrary_*_gfx1011.*"; + return "TensileLibrary_*_gfx1011"; case LazyLoadingInit::gfx1012: - return "TensileLibrary_*_gfx1012.*"; + return "TensileLibrary_*_gfx1012"; case LazyLoadingInit::gfx1030: - return "TensileLibrary_*_gfx1030.*"; + return "TensileLibrary_*_gfx1030"; case LazyLoadingInit::gfx1031: - return "TensileLibrary_*_gfx1031.*"; + return "TensileLibrary_*_gfx1031"; case LazyLoadingInit::gfx1032: - return "TensileLibrary_*_gfx1032.*"; + return "TensileLibrary_*_gfx1032"; case LazyLoadingInit::gfx1034: - return "TensileLibrary_*_gfx1034.*"; + return "TensileLibrary_*_gfx1034"; case LazyLoadingInit::gfx1035: - return "TensileLibrary_*_gfx1035.*"; + return "TensileLibrary_*_gfx1035"; case LazyLoadingInit::gfx1100: - return "TensileLibrary_*_gfx1100.*"; + return "TensileLibrary_*_gfx1100"; case LazyLoadingInit::gfx1101: - return "TensileLibrary_*_gfx1101.*"; + return "TensileLibrary_*_gfx1101"; case LazyLoadingInit::gfx1102: - return "TensileLibrary_*_gfx1102.*"; + return "TensileLibrary_*_gfx1102"; case LazyLoadingInit::None: return ""; } diff --git a/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp b/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp index faba922bb..def53f5a0 100644 --- a/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp +++ b/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp @@ -71,10 +71,10 @@ namespace Tensile }; template - struct MappingTraits + struct MappingTraits { using iot = IOTraits; - static void mapping(IO& io, ContractionSolution::SizeMapping& s) + static void mapping(IO& io, SizeMapping& s) { iot::mapRequired(io, "workGroup", s.workGroupSize); 
iot::mapRequired(io, "threadTile", s.threadTile); @@ -89,6 +89,7 @@ namespace Tensile iot::mapOptional(io, "packBatchDims", s.packBatchDims); iot::mapOptional(io, "packSummationDims", s.packSummationDims); iot::mapOptional(io, "magicDivAlg", s.magicDivAlg); + iot::mapOptional(io, "streamK", s.streamK); iot::mapRequired(io, "persistentKernel", s.persistentKernel); iot::mapRequired(io, "persistentKernelAlongBatch", s.persistentKernelAlongBatch); iot::mapRequired(io, "sourceKernel", s.sourceKernel); diff --git a/Tensile/Source/lib/include/Tensile/Serialization/PlaceholderLibrary.hpp b/Tensile/Source/lib/include/Tensile/Serialization/PlaceholderLibrary.hpp index 2b090ec72..281875ea0 100644 --- a/Tensile/Source/lib/include/Tensile/Serialization/PlaceholderLibrary.hpp +++ b/Tensile/Source/lib/include/Tensile/Serialization/PlaceholderLibrary.hpp @@ -31,7 +31,12 @@ #include #include -#include +//Replace std::regex, as it crashes when matching long lines(GCC Bug #86164). +#ifdef WIN32 +#include "shlwapi.h" +#else +#include +#endif namespace Tensile { @@ -68,7 +73,11 @@ namespace Tensile for(auto condition : ctx->preloaded) { std::string pattern = RegexPattern(condition); - if(std::regex_search(lib.filePrefix, std::regex(pattern))) +#ifdef WIN32 + if(PathMatchSpecA(lib.filePrefix.c_str(), pattern.c_str())) +#else + if(fnmatch(pattern.c_str(), lib.filePrefix.c_str(), 0) == 0) +#endif { lib.loadPlaceholderLibrary(); break; diff --git a/Tensile/Source/lib/include/Tensile/Serialization/Predicates.hpp b/Tensile/Source/lib/include/Tensile/Serialization/Predicates.hpp index b7450dcef..8795fd752 100644 --- a/Tensile/Source/lib/include/Tensile/Serialization/Predicates.hpp +++ b/Tensile/Source/lib/include/Tensile/Serialization/Predicates.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -165,6 +165,7 @@ namespace Tensile { SubclassMap rv({Base::template Pair(), Base::template Pair(), + Base::template Pair(), Base::template Pair()}); auto gmap = Generic::GetSubclasses(); @@ -193,6 +194,12 @@ namespace Tensile { }; + template + struct MappingTraits + : public AutoMappingTraits + { + }; + template struct MappingTraits : public AutoMappingTraits diff --git a/Tensile/Source/lib/source/AMDGPU.cpp b/Tensile/Source/lib/source/AMDGPU.cpp index 25f0b7b4a..ba50c9a4f 100644 --- a/Tensile/Source/lib/source/AMDGPU.cpp +++ b/Tensile/Source/lib/source/AMDGPU.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,16 +35,19 @@ namespace Tensile TENSILE_API AMDGPU::AMDGPU() {} - TENSILE_API AMDGPU::AMDGPU(AMDGPU::Processor p, int cus, std::string const& name) + TENSILE_API AMDGPU::AMDGPU(AMDGPU::Processor p, int cus, int apu, std::string const& name) : processor(p) , computeUnitCount(cus) + , isAPU(apu) , deviceName(name) { } - TENSILE_API AMDGPU::AMDGPU(std::string const& archName, int cus, std::string const& name) + TENSILE_API + AMDGPU::AMDGPU(std::string const& archName, int cus, int apu, std::string const& name) : processor(toProcessorId(archName)) , computeUnitCount(cus) + , isAPU(apu) , deviceName(name) { } diff --git a/Tensile/Source/lib/source/ContractionProblem.cpp b/Tensile/Source/lib/source/ContractionProblem.cpp index f2b3f9236..20724e037 100644 --- a/Tensile/Source/lib/source/ContractionProblem.cpp +++ b/Tensile/Source/lib/source/ContractionProblem.cpp @@ -754,16 +754,8 @@ namespace Tensile m_bZeroPads.push_back(m_boundIndices[toBoundsPos(zp.boundIndex)].bZeroPad); } - void ContractionProblem::checkPersistentKernelEligibility(ContractionSolution const& solution, - Hardware const& hardware) + size_t ContractionProblem::getNumTiles(SizeMapping const& sizeMapping) const { - m_eligibleForPK = true; - - // Get the new WorkGroup numbers under the PK and CU value - auto sizeMapping = solution.sizeMapping; - if(sizeMapping.persistentKernel == 0) - return; - // Get the normal WorkGroup numbers by sizeMapping MacroTile dim3 numWG(1, 1, 1); for(size_t i = 0; i < m_freeIndicesA.size(); i++) @@ -789,9 +781,38 @@ namespace Tensile numWG.y *= sizeMapping.globalSplitU; size_t problemTiles = numWG.x * numWG.y; + // if(sizeMapping.persistentKernelAlongBatch || sizeMapping.streamK != 0) if(sizeMapping.persistentKernelAlongBatch) problemTiles *= numWG.z; + return problemTiles; + } + + size_t ContractionProblem::getItersPerTile(SizeMapping const& sizeMapping) const + { + size_t boundSize = 1; + for(size_t i = 0; i < m_boundIndices.size(); ++i) + { + boundSize *= m_boundSizes[i]; + } + + size_t itersPerTile = CeilDivide(boundSize, sizeMapping.depthU); + + return itersPerTile; + } + + void ContractionProblem::checkPersistentKernelEligibility(ContractionSolution const& solution, + Hardware const& hardware) + { + m_eligibleForPK = true; + + // Get the new WorkGroup numbers under the PK and CU value + auto sizeMapping = solution.sizeMapping; + if(sizeMapping.persistentKernel == 0) + return; + + auto problemTiles = getNumTiles(sizeMapping); + AMDGPU const* pAMDGPU = dynamic_cast(&hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); @@ -812,6 +833,12 @@ namespace Tensile m_eligibleForPK = persistentGroups < problemTiles; } + void ContractionProblem::checkRequiredWorkspaceSize(ContractionSolution const& solution, + Hardware const& hardware) + { + m_requiredWorkspaceSize = solution.requiredWorkspaceSize(*this, hardware); + } + void ContractionProblem::normalize() { m_maxProblemSize = 0; diff --git a/Tensile/Source/lib/source/ContractionSolution.cpp b/Tensile/Source/lib/source/ContractionSolution.cpp index 9168a38c3..e9fb8590a 100644 --- a/Tensile/Source/lib/source/ContractionSolution.cpp +++ b/Tensile/Source/lib/source/ContractionSolution.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,11 @@ #include #include +#include +#include + +#define TENSILE_STREAMK_GRID 1 + namespace Tensile { PerfModel perf; @@ -304,6 +309,8 @@ namespace Tensile rv.numWorkGroups.x = CeilDivide(rv.numWorkGroups.x, sizeMapping.macroTile.x); rv.numWorkGroups.y = CeilDivide(rv.numWorkGroups.y, sizeMapping.macroTile.y); + auto numTiles = rv.numWorkGroups; + uint32_t problemNumGroupTiles0 = rv.numWorkGroups.x; uint32_t problemNumGroupTiles1 = rv.numWorkGroups.y; // used only when persistent kernel along batch @@ -311,12 +318,26 @@ namespace Tensile rv.numWorkGroups.y *= sizeMapping.globalSplitU; - if(sizeMapping.persistentKernel != 0) + size_t cuCount = 0; + if(sizeMapping.streamK != 0 || sizeMapping.persistentKernel != 0) { AMDGPU const* pAMDGPU = dynamic_cast(&hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); + cuCount = pAMDGPU->computeUnitCount; + } - size_t cuCount = pAMDGPU->computeUnitCount; + size_t skGrid = 0; + if(sizeMapping.streamK != 0) + { + skGrid = cuCount * TENSILE_STREAMK_GRID; + rv.numWorkGroups.x = skGrid; + rv.numWorkGroups.y = 1; + if(sizeMapping.persistentKernelAlongBatch) + rv.numWorkGroups.z = 1; + } + + if(sizeMapping.persistentKernel != 0) + { size_t finalPKValue = sizeMapping.persistentKernel; size_t problemGroups = rv.numWorkGroups.x * rv.numWorkGroups.y; if(sizeMapping.persistentKernelAlongBatch) @@ -366,7 +387,7 @@ namespace Tensile rv.args.append("tensor2dSizeB", tensor2dSizeB); } - if(sizeMapping.globalAccumulation) + if(sizeMapping.globalAccumulation && sizeMapping.streamK < 2) { rv.args.append("ws_d", inputs.ws); rv.args.append("ws_c", inputs.ws); @@ -393,6 +414,16 @@ namespace Tensile rv.args.append("batchB", inputs.batchB); } + if(sizeMapping.streamK >= 2) + { + // StreamK workspace + flags + rv.args.append("ws", inputs.ws); + void* ws = inputs.ws; + size_t flagsOffset = partialTileSize(skGrid); + void* flags = (void*)(static_cast(ws) + flagsOffset); + rv.args.append("Flags", flags); + } + rv.args.append("offsetD", d.offset()); rv.args.append("offsetC", c.offset()); rv.args.append("offsetA", a.offset()); @@ -412,7 +443,7 @@ namespace Tensile size_t startStrideCD = problemType.useInitialStridesCD ? 0 : 1; size_t startStrideAB = problemType.useInitialStridesAB ? 0 : 1; - if(sizeMapping.globalAccumulation) + if(sizeMapping.globalAccumulation && sizeMapping.streamK < 2) { size_t wsStride = startStrideCD ? 
d.sizes()[0] : 1; for(size_t i = startStrideCD; i < d.dimensions(); i++) @@ -551,15 +582,20 @@ namespace Tensile uint32_t magicNumberWgmRemainder1 = 0; // conditional args, aligned with KernelWriterAssembly.py - if(sizeMapping.persistentKernel != 0) + if(sizeMapping.persistentKernel != 0 || sizeMapping.streamK != 0) { uint32_t magicShift; rv.args.append("magicNumberProblemNumGroupTiles0", magicNumber(2, problemNumGroupTiles0, &magicShift)); rv.args.append("magicShiftProblemNumGroupTiles0", magicShift); + } + + if(sizeMapping.persistentKernel != 0) + { rv.args.append("gridNumWorkGroups0", rv.numWorkGroups.x); } + // if(sizeMapping.persistentKernelAlongBatch || sizeMapping.streamK != 0) if(sizeMapping.persistentKernelAlongBatch) { uint32_t numGroupTiles0x1 = problemNumGroupTiles0 * problemNumGroupTiles1; @@ -571,6 +607,46 @@ namespace Tensile rv.args.append("magicShiftProblemNumGroupTiles0By1", magicShift); } + if(sizeMapping.streamK != 0) + { + auto itersPerTile = problem.getItersPerTile(sizeMapping); + auto tiles = problem.getNumTiles(sizeMapping); + auto totalIters = tiles * itersPerTile; + uint32_t magicNumberItersPerTile; + uint32_t magicShiftItersPerTile; + magicNumberItersPerTile = magicNumber(2, itersPerTile, &magicShiftItersPerTile); + + rv.args.append("itersPerTile", itersPerTile); + rv.args.append("magicNumberItersPerTile", magicNumberItersPerTile); + rv.args.append("magicShiftItersPerTile", magicShiftItersPerTile); + rv.args.append("totalIters", totalIters); + if(sizeMapping.streamK < 3) // Basic SK + { + uint32_t itersPerWave = CeilDivide(totalIters, rv.numWorkGroups.x); + rv.args.append("SKItersPerWG", itersPerWave); + } + else if(sizeMapping.streamK == 3) // Two-tile SK + { + uint32_t numOutputTiles = tiles; + bool bigEnough = numOutputTiles > skGrid; + // skTiles is number of Stream-K tiles to complete + // Two-tile algorithm causes each WG to run an even number of Stream-K iterations, + // followed by an even number of data-parllel tiles + uint32_t skTiles + = bigEnough ? skGrid + numOutputTiles % skGrid : numOutputTiles; + // Number of data-parallel tiles on each workgroup would be: + // dpTilesPerWG = bigEnough ? 
(numOutputTiles - skTiles) / skGrid : 0; + + uint32_t skItersPerWG = skTiles * itersPerTile / skGrid; + uint32_t skExtraIters = skTiles * itersPerTile % (skGrid); + + rv.args.append("SKItersPerWG", skItersPerWG); + rv.args.append("skGrid", skGrid); + rv.args.append("skTiles", skTiles); + rv.args.append("skExtraIters", skExtraIters); + } + } + if(sizeMapping.workGroupMapping != 0) { numFullBlocks = problemNumGroupTiles1 / sizeMapping.workGroupMapping; @@ -611,6 +687,66 @@ namespace Tensile return sizeMapping.sourceKernel; } + template + KernelInvocation ContractionSolution::generateStreamKInitCall(Problem const& problem, + TypedInputs const& inputs, + Hardware const& hardware) const + { + TensorDescriptor const& c = problem.c(); + TensorDescriptor const& d = problem.d(); + + KernelInvocation rv; + + rv.args = KernelArguments(T_Debug); + + rv.args.reserve(512, 64); + + rv.kernelName = streamKInitKernelName(problem, inputs, hardware); + + rv.workGroupSize.x = 256; + rv.workGroupSize.y = 1; + rv.workGroupSize.z = 1; + + AMDGPU const* pAMDGPU = dynamic_cast(&hardware); + assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); + size_t cuCount = pAMDGPU->computeUnitCount; + size_t skGrid = cuCount * TENSILE_STREAMK_GRID; + size_t wiZ = 1; + for(size_t i = 0; i < problem.batchIndices().size(); i++) + wiZ *= problem.batchSize(i); + size_t flagCount = skGrid * wiZ; + + rv.numWorkGroups.x = CeilDivide(flagCount, rv.workGroupSize.x); + rv.numWorkGroups.y = 1; + rv.numWorkGroups.z = 1; + + rv.numWorkItems.x = rv.workGroupSize.x * rv.numWorkGroups.x; + rv.numWorkItems.y = rv.workGroupSize.y * rv.numWorkGroups.y; + rv.numWorkItems.z = rv.workGroupSize.z * rv.numWorkGroups.z; + + void* ws = inputs.ws; + size_t flagsOffset = partialTileSize(skGrid); + void* flags = (void*)(static_cast(ws) + flagsOffset); + rv.args.append("Flags", flags); + + rv.args.append("flagCount", flagCount); + + //Pass along code object dependency + // TODO check this + rv.codeObjectFile = codeObjectFilename.load(); + + return rv; + } + + template + std::string ContractionSolution::streamKInitKernelName(Problem const& problem, + TypedInputs const& inputs, + Hardware const& hardware) const + { + std::string name = "WSFlags"; + return name; + } + template KernelInvocation ContractionSolution::generateBetaOnlyCall(Problem const& problem, TypedInputs const& inputs, @@ -725,6 +861,10 @@ namespace Tensile KernelInvocation ContractionSolution::generateOutputConversionCall( Problem const& problem, TypedInputs const& inputs, Hardware const& hardware) const { + AMDGPU const* pAMDGPU = dynamic_cast(&hardware); + assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); + size_t cuCount = pAMDGPU->computeUnitCount; + TensorDescriptor const& c = problem.c(); TensorDescriptor const& d = problem.d(); @@ -734,8 +874,6 @@ namespace Tensile rv.args.reserve(512, 64); - rv.kernelName = outputConversionKernelName(problem, inputs, hardware); - rv.workGroupSize.x = 256; rv.workGroupSize.y = 1; rv.workGroupSize.z = 1; @@ -750,7 +888,64 @@ namespace Tensile for(size_t i = 0; i < problem.batchIndices().size(); i++) wiZ *= problem.batchSize(i); - rv.numWorkGroups.x = CeilDivide(wiX * wiY * wiZ, rv.workGroupSize.x); + const unsigned int numThreadsPerCU = 256; + unsigned int gsu = static_cast(sizeMapping.globalSplitU); + if(sizeMapping.globalAccumulation == 1) + // globalAccumulation = 1 case, ignore globalSplitU and use 1 + gsu = 1; + // wider global load for postGSU + // only for compute type = Float + bool supportedTypeForVWopt + = 
problem.alphaType() == DataType::Float || problem.alphaType() == DataType::Double; + size_t total = wiX * wiY * wiZ; + int vw = 1; + size_t threshVW2 + = cuCount * numThreadsPerCU * 2; // should be more than number of physical threads * vw + if(supportedTypeForVWopt && total > threshVW2 && problem.freeSizeA(0) % 2 == 0) + vw = 2; + + // parallel reduction width calculation + // only for compute type = Float + bool supportedTypeForReductionOpt + = problem.alphaType() == DataType::Float || problem.alphaType() == DataType::Double; + size_t threshReduction + = cuCount + * numThreadsPerCU; // should be less than number of physical threads / reduction + const unsigned int maxReductionConst = 4; + const unsigned int minGSUperReduction + = 32; // Minimum GSU=128 for Reduction=4, GSU=64 for Reduction2 + unsigned int maxReduction = std::min(maxReductionConst, gsu / minGSUperReduction); + unsigned int reduction = 1; + if(supportedTypeForReductionOpt && (gsu & (gsu - 1)) == 0 && maxReduction > 1) + { + // apply reduction only if float compute type and gsu is power of 2 (for small array only) + reduction = maxReduction; + while(reduction > 1) + { + size_t totalThread = total * reduction / vw; + if(gsu % reduction == 0 && totalThread <= threshReduction) + // found an applicable reduction + break; + // for next loop + reduction /= 2; + } + if(reduction == 2 && problem.alphaType() == DataType::Float) + // so far, reduction=2 does not perform well + reduction = 1; + } + + // GSU loop unroll opt + // only for compute type = Float, Double + bool supportedTypeForUnrollOpt + = problem.alphaType() == DataType::Float || problem.alphaType() == DataType::Double; + int gsuUnrollUnit = 16 * reduction; // must match Tensile generator code + if(supportedTypeForUnrollOpt == false || sizeMapping.globalAccumulation == 1) + gsuUnrollUnit = 1; + + rv.kernelName = outputConversionKernelName( + problem, inputs, gsu, vw, reduction, gsuUnrollUnit, hardware); + + rv.numWorkGroups.x = CeilDivide(wiX * wiY * wiZ * reduction, rv.workGroupSize.x * vw); rv.numWorkGroups.y = 1; rv.numWorkGroups.z = 1; @@ -803,10 +998,7 @@ namespace Tensile idx++; } - if(sizeMapping.globalAccumulation == 1) - rv.args.append("gsu", 1); - else - rv.args.append("gsu", sizeMapping.globalSplitU); + rv.args.append("gsu", gsu); if(problemType.stochasticRounding) { @@ -849,6 +1041,10 @@ namespace Tensile template std::string ContractionSolution::outputConversionKernelName(Problem const& problem, TypedInputs const& inputs, + int gsu, + int vw, + int reduction, + int gsuUnrollUnit, Hardware const& hardware) const { std::string name = concatenate( @@ -861,6 +1057,27 @@ namespace Tensile name += "_PostGSU"; + // add extra string for gsu + // This part must match tensile code generation (in KernelWriterConversion.py) + // add mod related str (only for gsuUnrollUnit > 1) + if(gsuUnrollUnit > 1) + { + size_t gsuMod = gsu % gsuUnrollUnit; + if(gsuMod == 0) + gsuMod = gsuUnrollUnit; + std::string modStr = ""; + if(gsu > gsuUnrollUnit) + { + modStr = "_mod" + std::to_string(gsuUnrollUnit); + } + name += std::to_string(gsuMod) + modStr; + } + if(vw >= 2) + name += "_VW" + std::to_string(vw); + + if(reduction >= 2) + name += "_R" + std::to_string(reduction); + return name; } @@ -917,7 +1134,17 @@ namespace Tensile std::vector rv; - if(sizeMapping.globalSplitU > 1 && sizeMapping.globalAccumulation != 2) + if(sizeMapping.streamK >= 2) + { + if(debug) + rv.push_back(generateStreamKInitCall(problem, inputs, hardware)); + else + rv.push_back( + 
generateStreamKInitCall(problem, inputs, hardware)); + } + + if(sizeMapping.streamK == 1 + || (sizeMapping.globalSplitU > 1 && sizeMapping.globalAccumulation != 2)) { if(debug) rv.push_back(generateBetaOnlyCall(problem, inputs, hardware)); @@ -930,8 +1157,9 @@ namespace Tensile else rv.push_back(generateSingleCall(problem, inputs, hardware)); - if(sizeMapping.globalAccumulation) + if(sizeMapping.globalAccumulation && !sizeMapping.streamK) { + // TODO Streamk May need conversion call for HPA?? if(debug) rv.push_back( generateOutputConversionCall(problem, inputs, hardware)); @@ -943,6 +1171,55 @@ namespace Tensile return rv; } + bool ContractionSolution::getMatrixInstructionFromKernelName( + vector4& matrixInst) const + { + + std::string regexp_string("_MI(\\d+)x(\\d+)x(\\d+)x(\\d+)_"); + std::regex miRegex(regexp_string); + std::smatch matches; + + std::string kName = this->KernelName(); + + bool matched + = std::regex_search(kName, matches, miRegex) + && (matches.size() + == 5 /* 1 (index 0) for entire match and 4 for the sub-experssion matches */); + + if(matched) + { + // the first sub_match element (index 0) corresponds to the entire match + matrixInst.x = atoi(matches[1].str().c_str()); + matrixInst.y = atoi(matches[2].str().c_str()); + matrixInst.z = atoi(matches[3].str().c_str()); + matrixInst.w = atoi(matches[4].str().c_str()); + } + + return matched; + } + + bool ContractionSolution::getGSUAlgorithmFromKernelName(std::string& gsuAlg) const + { + std::string regexp_string("_GSUA([S|M])B_"); + std::regex miRegex(regexp_string); + std::smatch matches; + + std::string kName = this->KernelName(); + + bool matched + = std::regex_search(kName, matches, miRegex) + && (matches.size() + == 2 /* 1 (index 0) for entire match and 1 for the sub-experssion matches */); + + if(matched) + { + // the first sub_match element (index 0) corresponds to the entire match + gsuAlg = (matches[1].compare("S") == 0) ? 
"SingleBuffer" : "MultipleBuffer"; + } + + return matched; + } + std::vector ContractionSolution::solve(ContractionSolution::Problem const& problem, ContractionSolution::Inputs const& inputs, @@ -951,6 +1228,22 @@ namespace Tensile if(Debug::Instance().printWinningKernelName()) std::cout << "Running kernel: " << this->KernelName() << std::endl; + if(Debug::Instance().printKernelCommonParams()) + { + std::cout << "Kernel name: " << this->KernelName() << std::endl; + std::cout << "Kernel parameters: " << std::endl; + + vector4 matrixInst; + if(this->getMatrixInstructionFromKernelName(matrixInst)) + std::cout << std::right << std::setw(30) << "MatrixInstruction: " << matrixInst + << std::endl; + + std::string GSUAlg; + if(this->getGSUAlgorithmFromKernelName(GSUAlg)) + std::cout << std::right << std::setw(30) << "GSUAlgorithm: " << GSUAlg << std::endl; + + std::cout << this->sizeMapping << std::endl; + } // retreive alpha/beta type set via setAlpha/BetaType() auto alphaType = problem.alphaType(); auto betaType = problem.betaType(); @@ -1174,11 +1467,39 @@ namespace Tensile return spm; } - size_t ContractionSolution::requiredWorkspaceSize(Problem const& problem) const + size_t ContractionSolution::requiredWorkspaceSize(Problem const& problem, + Hardware const& hardware) const { size_t size = 0; - size += problem.d().totalLogicalElements() * sizeMapping.workspaceSizePerElemC; + if(sizeMapping.streamK >= 2) + { + AMDGPU const* pAMDGPU = dynamic_cast(&hardware); + assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); + size_t cuCount = pAMDGPU->computeUnitCount; + size_t skGrid = cuCount * TENSILE_STREAMK_GRID; + // Get space required for partial tiles + size += partialTileSize(skGrid); + // Add space for flags + // Flags for partial tiles - dword per flag for fast addressing and comparisons + size += skGrid * 4; + // size *= batches; // TODO need tile and flag per batch + } + else + size += problem.d().totalLogicalElements() * sizeMapping.workspaceSizePerElemC; + + return size; + } + + size_t ContractionSolution::partialTileSize(size_t skGrid) const + { + size_t size = 0; + + size_t tileSize + = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; + size += tileSize * skGrid; // Partials tile per WG + // TODO batches + // TODO round up for alignment? 
return size; } @@ -1468,4 +1789,41 @@ namespace Tensile << " shiftPtrElemB=" << st.shiftPtrElemB << " depthUorMT0=" << st.depthUorMT0 << " depthUorMT1=" << st.depthUorMT1; } + + std::ostream& operator<<(std::ostream& stream, const SizeMapping& sizeMapping) + { + std::ios_base::fmtflags flags(stream.flags()); + + stream << std::right << std::setw(30) << "workGroupSize: " << sizeMapping.workGroupSize + << std::endl + << std::setw(30) << "threadTile: " << sizeMapping.threadTile << std::endl + << std::setw(30) << "macroTile: " << sizeMapping.macroTile << std::endl + + << std::setw(30) << "staggerU: " << sizeMapping.staggerU << std::endl + << std::setw(30) << "depthU: " << sizeMapping.depthU << std::endl + << std::setw(30) << "globalSplitU: " << sizeMapping.globalSplitU << std::endl + << std::setw(30) << "staggerStrideShift: " << sizeMapping.staggerStrideShift + << std::endl + << std::setw(30) << "workGroupMapping: " << sizeMapping.workGroupMapping << std::endl + + << std::setw(30) << "packBatchDims: " << sizeMapping.packBatchDims << std::endl + << std::setw(30) << "packSummationDims: " << sizeMapping.packSummationDims + << std::endl + << std::setw(30) << "magicDivAlg: " << sizeMapping.magicDivAlg << std::endl + << std::setw(30) << "streamK: " << sizeMapping.streamK << std::endl + << std::setw(30) << "persistentKernel: " << sizeMapping.persistentKernel << std::endl + << std::setw(30) + << "persistentKernelAlongBatch: " << sizeMapping.persistentKernelAlongBatch + << std::endl + + << std::setw(30) << "sourceKernel: " << sizeMapping.sourceKernel << std::endl + << std::setw(30) << "globalAccumulation: " << sizeMapping.globalAccumulation + << std::endl + << std::setw(30) << "workspaceSizePerElemC: " << sizeMapping.workspaceSizePerElemC + << std::endl; + + stream.flags(flags); + + return stream; + } } // namespace Tensile diff --git a/Tensile/Source/lib/source/Debug.cpp b/Tensile/Source/lib/source/Debug.cpp index 8d9832a5a..e37e8ae1f 100644 --- a/Tensile/Source/lib/source/Debug.cpp +++ b/Tensile/Source/lib/source/Debug.cpp @@ -113,6 +113,11 @@ namespace Tensile return m_value & 0x8000; } + bool Debug::printKernelCommonParams() const + { + return m_value & 0x80000; + } + bool Debug::printSolutionSelectionTime() const { return m_value & 0x10000; diff --git a/Tensile/Source/lib/source/hip/HipHardware.cpp b/Tensile/Source/lib/source/hip/HipHardware.cpp index ffa4f0fcd..411d7c575 100644 --- a/Tensile/Source/lib/source/hip/HipHardware.cpp +++ b/Tensile/Source/lib/source/hip/HipHardware.cpp @@ -33,8 +33,10 @@ namespace Tensile namespace hip { HipAMDGPU::HipAMDGPU(hipDeviceProp_t const& prop) - : AMDGPU( - std::string(prop.gcnArchName), prop.multiProcessorCount, std::string(prop.name)) + : AMDGPU(std::string(prop.gcnArchName), + prop.multiProcessorCount, + prop.directManagedMemAccessFromHost, + std::string(prop.name)) , properties(prop) { } @@ -63,6 +65,10 @@ namespace Tensile HIP_CHECK_EXC(hipDeviceGetAttribute(&prop.multiProcessorCount, hipDeviceAttributePhysicalMultiProcessorCount, deviceId)); + HIP_CHECK_EXC( + hipDeviceGetAttribute(&prop.directManagedMemAccessFromHost, + hipDeviceAttributeDirectManagedMemAccessFromHost, + deviceId)); } #endif diff --git a/Tensile/Source/lib/source/hip/HipSolutionAdapter.cpp b/Tensile/Source/lib/source/hip/HipSolutionAdapter.cpp index 3dff78dd9..d3aa6d7ce 100644 --- a/Tensile/Source/lib/source/hip/HipSolutionAdapter.cpp +++ b/Tensile/Source/lib/source/hip/HipSolutionAdapter.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. 
All rights reserved. + * Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -230,6 +230,10 @@ namespace Tensile { return err; } + else + { + (void)hipGetLastError(); // clear hipErrorNotFound + } } return err; @@ -342,6 +346,16 @@ namespace Tensile &argsSize, HIP_LAUNCH_PARAM_END}; + if(m_debug) + { + int numBlocks = 0; + int blockSize + = kernel.workGroupSize.x * kernel.workGroupSize.y * kernel.workGroupSize.z; + HIP_CHECK_RETURN(hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( + &numBlocks, function, blockSize, 0)); + std::cout << "Occupancy = " << numBlocks << std::endl; + } + if(startEvent != nullptr) HIP_CHECK_RETURN(hipEventRecord(startEvent, stream)); HIP_CHECK_RETURN(hipExtModuleLaunchKernel(function, diff --git a/Tensile/Source/lib/source/ocl/OclHardware.cpp b/Tensile/Source/lib/source/ocl/OclHardware.cpp index f58249334..178611d3b 100644 --- a/Tensile/Source/lib/source/ocl/OclHardware.cpp +++ b/Tensile/Source/lib/source/ocl/OclHardware.cpp @@ -36,9 +36,11 @@ namespace Tensile namespace ocl { OclAMDGPU::OclAMDGPU(oclDeviceProp_t const& prop) - : AMDGPU(static_cast(prop.gcnArch), - prop.multiProcessorCount, - std::string(prop.name)) + : AMDGPU( + static_cast(prop.gcnArch), + prop.multiProcessorCount, + 0, // hipDeviceAttributeDirectManagedMemAccessFromHost not accessible through OCL interface + std::string(prop.name)) , properties(prop) { } diff --git a/Tensile/TensileCreateLibrary.py b/Tensile/TensileCreateLibrary.py index 9aa0f0258..69a53cc65 100644 --- a/Tensile/TensileCreateLibrary.py +++ b/Tensile/TensileCreateLibrary.py @@ -149,7 +149,10 @@ def getAssemblyCodeObjectFiles(kernels, kernelWriterAssembly, outputPath): return coFiles def which(p): - exes = [p+x for x in ['.bat', '', '.exe']] # bat may be front end for file with no extension + if os.name == "nt": + exes = [p+x for x in ['.bat', '', '.exe']] # bat may be front end for file with no extension + else: + exes = [p+x for x in ['', '.exe', '.bat']] system_path = os.environ['PATH'].split(os.pathsep) if p == 'hipcc' and 'CMAKE_CXX_COMPILER' in os.environ and os.path.isfile(os.environ['CMAKE_CXX_COMPILER']): return os.environ['CMAKE_CXX_COMPILER'] diff --git a/Tensile/Tests/emulation/mfma/1LDSB.yaml b/Tensile/Tests/emulation/mfma/1LDSB.yaml index a0a519d44..eb2c2c326 100644 --- a/Tensile/Tests/emulation/mfma/1LDSB.yaml +++ b/Tensile/Tests/emulation/mfma/1LDSB.yaml @@ -46,7 +46,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2,4] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [-1] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] - LdsPadA: [-1] - LdsPadB: [-1] - 1LDSBuffer: [1] @@ -94,7 +95,8 @@ BenchmarkProblems: # - ScheduleIterAlg: [3] # - InnerUnroll: [1,2,4] # - TransposeLDS: [0,1] - # - LdsBlockSizePerPad: [-1] + # - LdsBlockSizePerPadA: [-1] + # - LdsBlockSizePerPadB: [-1] # - LdsPadA: [-1] # - LdsPadB: [-1] # - 1LDSBuffer: [1] @@ -141,7 +143,8 @@ BenchmarkProblems: # - ScheduleIterAlg: [3] # - InnerUnroll: [1,2] # - TransposeLDS: [0,1] - # - LdsBlockSizePerPad: [-1] + # - LdsBlockSizePerPadA: [-1] + # - LdsBlockSizePerPadB: [-1] # - LdsPadA: [-1] # - LdsPadB: [-1] # - 1LDSBuffer: [1] diff --git a/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn.yaml b/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn.yaml index 6bb9fe680..0bcac773d 100644 --- 
a/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn.yaml +++ b/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn.yaml @@ -45,6 +45,10 @@ BenchmarkProblems: - VectorAtomicWidth: [1] - VectorWidth: [2] - WorkGroupMapping: [4] + - LdsPadA: [0] + - LdsPadB: [0] + - LdsBlockSizePerPadA: [0] + - LdsBlockSizePerPadB: [0] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn_large_offset.yaml b/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn_large_offset.yaml index 956267b3f..4e433fa48 100644 --- a/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn_large_offset.yaml +++ b/Tensile/Tests/extended/custom_kernel/ck_dgemm_90a_nn_large_offset.yaml @@ -46,6 +46,10 @@ BenchmarkProblems: - VectorAtomicWidth: [1] - VectorWidth: [2] - WorkGroupMapping: [4] + - LdsPadA: [0] + - LdsPadB: [0] + - LdsBlockSizePerPadA: [0] + - LdsBlockSizePerPadB: [0] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/direct_to_vgpr/dtv_hgemm.yaml b/Tensile/Tests/extended/direct_to_vgpr/dtv_hgemm.yaml index 7fe05b5f5..1c1212fc2 100644 --- a/Tensile/Tests/extended/direct_to_vgpr/dtv_hgemm.yaml +++ b/Tensile/Tests/extended/direct_to_vgpr/dtv_hgemm.yaml @@ -219,7 +219,7 @@ BenchmarkProblems: - PrefetchGlobalRead: [1]#[1,2] - PrefetchLocalRead: [2,3,5,9] - ScheduleIterAlg: [3] - - StaggerU: [0,32] + #- StaggerU: [0,32] - SourceSwap: [1]#[0,1] - TransposeLDS: [1] - GlobalReadVectorWidth: [2,4,8] @@ -236,6 +236,7 @@ BenchmarkProblems: - GlobalSplitU: [1,2] - GlobalSplitUAlgorithm: ["SingleBuffer"] - VgprForLocalReadPacking: [1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: @@ -391,7 +392,7 @@ BenchmarkProblems: - PrefetchGlobalRead: [1,2] - PrefetchLocalRead: [1,2,3,5,9] - ScheduleIterAlg: [3] - - StaggerU: [0,32] + #- StaggerU: [0,32] - SourceSwap: [1]#[0,1] - TransposeLDS: [1] - GlobalReadVectorWidth: [4,8] @@ -406,6 +407,7 @@ BenchmarkProblems: - GlobalSplitU: [1,2] - GlobalSplitUAlgorithm: ["SingleBuffer"] - VgprForLocalReadPacking: [0,1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/direct_to_vgpr/dtv_igemm.yaml b/Tensile/Tests/extended/direct_to_vgpr/dtv_igemm.yaml index dfdb51d4b..a37995b89 100644 --- a/Tensile/Tests/extended/direct_to_vgpr/dtv_igemm.yaml +++ b/Tensile/Tests/extended/direct_to_vgpr/dtv_igemm.yaml @@ -665,7 +665,7 @@ BenchmarkProblems: - PrefetchGlobalRead: [1]#[1,2] - PrefetchLocalRead: [2,3,5,9] - ScheduleIterAlg: [3] - - StaggerU: [0,32] + #- StaggerU: [0,32] - SourceSwap: [1]#[0,1] - TransposeLDS: [1] - GlobalReadVectorWidth: [4,8,16] @@ -680,6 +680,7 @@ BenchmarkProblems: - GlobalSplitU: [1,2] - GlobalSplitUAlgorithm: ["SingleBuffer"] - VgprForLocalReadPacking: [1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/local_split_u/f8gemm_lsu_mfma.yaml b/Tensile/Tests/extended/local_split_u/f8gemm_lsu_mfma.yaml index 0f0095f08..b276061e8 100644 --- a/Tensile/Tests/extended/local_split_u/f8gemm_lsu_mfma.yaml +++ b/Tensile/Tests/extended/local_split_u/f8gemm_lsu_mfma.yaml @@ -75,6 +75,7 @@ BenchmarkProblems: - BufferLoad: [0,1] - BufferStore: [0,1] - VgprForLocalReadPacking: [0,1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: @@ -365,6 +366,7 @@ BenchmarkProblems: - TransposeLDS: [1] - LocalReadVectorWidth: [-1] - 
VgprForLocalReadPacking: [0,1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/local_split_u/igemm_lsu_mfma.yaml b/Tensile/Tests/extended/local_split_u/igemm_lsu_mfma.yaml index 892a185e2..e1b715732 100644 --- a/Tensile/Tests/extended/local_split_u/igemm_lsu_mfma.yaml +++ b/Tensile/Tests/extended/local_split_u/igemm_lsu_mfma.yaml @@ -74,6 +74,7 @@ BenchmarkProblems: - BufferLoad: [0,1] - BufferStore: [0,1] - VgprForLocalReadPacking: [0,1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: @@ -256,6 +257,7 @@ BenchmarkProblems: - LocalReadVectorWidth: [-1,8] - DirectToVgprB: [0,1] - VgprForLocalReadPacking: [0,1] + - ClusterLocalRead: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: diff --git a/Tensile/Tests/extended/local_split_u/sgemm_lsu_mfma.yaml b/Tensile/Tests/extended/local_split_u/sgemm_lsu_mfma.yaml index 7c6ad52d4..72cf67ed8 100644 --- a/Tensile/Tests/extended/local_split_u/sgemm_lsu_mfma.yaml +++ b/Tensile/Tests/extended/local_split_u/sgemm_lsu_mfma.yaml @@ -48,7 +48,6 @@ BenchmarkProblems: - GlobalSplitUWorkGroupMappingRoundRobin: [False] - GlobalSplitUSummationAssignmentRoundRobin: [True] - 1LDSBuffer: [0] - - LdsBlockSizePerPad: [128] #- AssertFree0ElementMultiple : [8] - AssertFree1ElementMultiple : [1,4] - ExpandPointerSwap: [1] # 1 for DirectToVgpr @@ -280,7 +279,8 @@ BenchmarkProblems: - GlobalSplitUWorkGroupMappingRoundRobin: [False] - GlobalSplitUSummationAssignmentRoundRobin: [True] - 1LDSBuffer: [0] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] #- AssertFree0ElementMultiple : [8] - AssertFree1ElementMultiple : [1,4] - ExpandPointerSwap: [1] # 1 for DirectToVgpr diff --git a/Tensile/Tests/extended/stream_k/sk_2tile_hgemm_hhs.yaml b/Tensile/Tests/extended/stream_k/sk_2tile_hgemm_hhs.yaml new file mode 100644 index 000000000..ae148065f --- /dev/null +++ b/Tensile/Tests/extended/stream_k/sk_2tile_hgemm_hhs.yaml @@ -0,0 +1,88 @@ +TestParameters: + marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102] # not supported by arch + +GlobalParameters: + NumElementsToValidate: -1 + BoundsCheck: True + KernelTime: True + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + # DataInitTypeA: 1 + # DataInitTypeB: 1 + # DataInitTypeC: 1 + # ValidationPrintValids: True + MaxWorkspaceSize: 134217728 + # PrintSolutionRejectionReason: True + # ForceGenerateKernel: True + # GenerateSourcesAndExit: True + # NumWarmups: 2 + # EnqueuesPerSync: 10 + # NumBenchmarks: 10 + +BenchmarkProblems: + - # sgemm NT + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True # True if DataType != ComputeDataType + TransposeA: False + TransposeB: True + UseBeta: False + Batched: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - PrefetchLocalRead: [True] + ForkParameters: + - MatrixInstruction: + - [32, 32, 8, 1, 1, 4,4, 2,2] + - ThreadTile: + - [ 1, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + # - WorkGroupMapping: [0, 1, 2, 4, 8, 16, 32, 64] + - WorkGroupMapping: [0, 8] + - GlobalSplitU: [1] + - DepthU: [ 32 ] + # - DepthU: [ 8, 16, 32, 64 ] + - VectorWidth: [1] + - StreamK: [3] + - StaggerU: [0, 32] + - ScheduleIterAlg: [3] + - SourceSwap: [False, True] + # - 
ExpandPointerSwap: [False, True] + - ExpandPointerSwap: [False] + - PrefetchLocalRead: [5] + # - PrefetchLocalRead: [1, 3, 5, 9, 13, 17] + # - PrefetchGlobalRead: [1, 2] + - PrefetchGlobalRead: [1] + # - 1LDSBuffer: [0, 1] + - 1LDSBuffer: [1] + # - EdgeType: ["Branch", "ShiftPtr"] + - EdgeType: ["ShiftPtr"] + # - MIArchVgpr: [0, 1] + - MIArchVgpr: [0] + # - StoreVectorWidth: [4, 1] + - StoreVectorWidth: [1] + # - NumElementsPerBatchStore: [0, 2, 4, 8] + # - NumElementsPerBatchStore: [8] + - AssertAlphaValue: [1] + - GlobalReadVectorWidth: [8] + - AssertSizeEqual: [{2: 1}] + + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [ [4096], [4096], [1], [1024] ] + - Range: [ [4103], [4096], [1], [1024] ] + - Range: [ [4096], [4103], [1], [1024] ] + - Range: [ [4103], [4103], [1], [1024] ] + - Range: [ [4096], [4096], [1], [1031] ] + - Range: [ [4103], [4103], [1], [1031] ] diff --git a/Tensile/Tests/extended/stream_k/sk_2tile_sgemm.yaml b/Tensile/Tests/extended/stream_k/sk_2tile_sgemm.yaml new file mode 100644 index 000000000..96cb4f808 --- /dev/null +++ b/Tensile/Tests/extended/stream_k/sk_2tile_sgemm.yaml @@ -0,0 +1,119 @@ +TestParameters: + marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102] # not supported by arch + +GlobalParameters: + NumElementsToValidate: -1 + BoundsCheck: True + KernelTime: True + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + # DataInitTypeA: 1 + # DataInitTypeB: 1 + # DataInitTypeC: 1 + # ValidationPrintValids: True + MaxWorkspaceSize: 134217728 + # PrintSolutionRejectionReason: True + # ForceGenerateKernel: True + # GenerateSourcesAndExit: True + # NumWarmups: 2 + # EnqueuesPerSync: 10 + # NumBenchmarks: 10 + # BufferOffsetA: 500 + # BufferOffsetB: 700 + # BufferOffsetC: 900 + # BufferOffsetD: 1100 + +BenchmarkProblems: + - # sgemm NT + - # ProblemType + OperationType: GEMM + DataType: s + TransposeA: False + TransposeB: True + UseBeta: False + Batched: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - PrefetchLocalRead: [True] + ForkParameters: + - MatrixInstruction: + - [32, 32, 2, 1, 1, 2,2, 2,2] + # - [32, 32, 2, 1, 1, 3,3, 2,2] + - [32, 32, 2, 1, 1, 4,4, 2,2] + # - [32, 32, 2, 1, 1, 2,1, 2,2] + # - [32, 32, 2, 1, 1, 1,1, 2,2] + # - [16, 16, 4, 1, 1, 3,3, 2,2] + # - [16, 16, 4, 1, 1, 4,1, 2,2] + # - [16, 16, 4, 1, 1, 4,2, 2,2] + - [16, 16, 4, 1, 1, 4,4, 2,2] + - [16, 16, 4, 1, 1, 8,8, 2,2] + # - [16, 16, 4, 1, 1, 2,2, 2,2] + # - [16, 16, 4, 1, 1, 2,1, 2,2] + # - [16, 16, 4, 1, 1, 1,1, 2,2] + - ThreadTile: + - [ 1, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + # - WorkGroupMapping: [0, 1, 2, 4, 8, 16, 32, 64] # works + - WorkGroupMapping: [0, 8] + - GlobalSplitU: [1] + - DepthU: [ 8 ] + # - DepthU: [ 8, 16, 32 ] + # - DepthU: [ 8, 12, 16, 32 ] + # - DepthU: [ 2, 4, 8, 16, 32, 64 ] + # - DepthU: [ 8, 9, 10, 11, 12, 13, 14, 15, 16 ] # depthu 14 failed a test + - VectorWidth: [1] + - StreamK: [3] + - StaggerU: [0, 32] + # - StaggerU: [0] + - ScheduleIterAlg: [3] + # - SourceSwap: [False, True] + - SourceSwap: [False] + # - ExpandPointerSwap: [False, True] + # - ExpandPointerSwap: [False] + - PrefetchLocalRead: [3, 5] + # - PrefetchLocalRead: [1, 3, 5, 9, 13, 17] + # - PrefetchLocalRead: [1, 9, 10, 11, 12, 13, 14, 15, 16, 17] + # - PrefetchGlobalRead: [1, 2] + - PrefetchGlobalRead: [1] + # - 
1LDSBuffer: [0, 1] + # - 1LDSBuffer: [1] + # - EdgeType: ["Branch", "ShiftPtr"] + # - EdgeType: ["ShiftPtr"] + # - MIArchVgpr: [0, 1] + # - MIArchVgpr: [1] + # - StoreVectorWidth: [4, 1] + - StoreVectorWidth: [4] + # - NumElementsPerBatchStore: [0, 2, 4, 8] + # - NumElementsPerBatchStore: [8] + # - AssertAlphaValue: [1] + - AssertSizeEqual: [{2: 1}] + + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + # - Exact: [10912, 10976, 1, 13856] + # - Exact: [16000, 13632, 1, 9040] + # - Range: [ [4096, 7, 6144], [4096, 7, 6144], [1], [1024] ] + # - Range: [ [4096, 31, 5120], [4096, 31, 5120], [1], [1024] ] + # - Range: [ [512], [512], [1], [512] ] + - Range: [ [4096], [4096], [1], [1024] ] + - Range: [ [4103], [4096], [1], [1024] ] + - Range: [ [4096], [4103], [1], [1024] ] + - Range: [ [4103], [4103], [1], [1024] ] + - Range: [ [4096], [4096], [1], [1031] ] + - Range: [ [4103], [4103], [1], [1031] ] + # - Range: [ [4096], [4096], [1], [1024, 1, 1088] ] + # - Range: [ [4096, 31, 5120], [4096, 31, 5120], [1], [1024, 7, 1280] ] + # - Range: [ [4096, 31, 5120], [4096], [1], [1024] ] + # - Range: [ [4096], [4096, 31, 5120], [1], [1024] ] + # - Range: [ [4096], [4096], [1], [1024, 7, 1280] ] + # - Range: [ [1031], [1031], [1], [1031] ] + # - Range: [ [1031], [1031], [8], [1031] ] + # - Range: [ [4096], [4096], [2], [1024] ] diff --git a/Tensile/Tests/extended/stream_k/sk_hgemm_hhs.yaml b/Tensile/Tests/extended/stream_k/sk_hgemm_hhs.yaml new file mode 100644 index 000000000..11398db76 --- /dev/null +++ b/Tensile/Tests/extended/stream_k/sk_hgemm_hhs.yaml @@ -0,0 +1,88 @@ +TestParameters: + marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102] # not supported by arch + +GlobalParameters: + NumElementsToValidate: -1 + BoundsCheck: True + KernelTime: True + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + # DataInitTypeA: 1 + # DataInitTypeB: 1 + # DataInitTypeC: 1 + # ValidationPrintValids: True + MaxWorkspaceSize: 134217728 + # PrintSolutionRejectionReason: True + # ForceGenerateKernel: True + # GenerateSourcesAndExit: True + # NumWarmups: 2 + # EnqueuesPerSync: 10 + # NumBenchmarks: 10 + +BenchmarkProblems: + - # sgemm NT + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True # True if DataType != ComputeDataType + TransposeA: False + TransposeB: True + UseBeta: False + Batched: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - PrefetchLocalRead: [True] + ForkParameters: + - MatrixInstruction: + - [32, 32, 8, 1, 1, 4,4, 2,2] + - ThreadTile: + - [ 1, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + # - WorkGroupMapping: [0, 1, 2, 4, 8, 16, 32, 64] + - WorkGroupMapping: [0, 8] + - GlobalSplitU: [1] + - DepthU: [ 32 ] + # - DepthU: [ 8, 16, 32, 64 ] + - VectorWidth: [1] + - StreamK: [2] + - StaggerU: [0, 32] + - ScheduleIterAlg: [3] + - SourceSwap: [False, True] + # - ExpandPointerSwap: [False, True] + - ExpandPointerSwap: [False] + - PrefetchLocalRead: [5] + # - PrefetchLocalRead: [1, 3, 5, 9, 13, 17] + # - PrefetchGlobalRead: [1, 2] + - PrefetchGlobalRead: [1] + # - 1LDSBuffer: [0, 1] + - 1LDSBuffer: [1] + # - EdgeType: ["Branch", "ShiftPtr"] + - EdgeType: ["ShiftPtr"] + # - MIArchVgpr: [0, 1] + - MIArchVgpr: [0] + # - StoreVectorWidth: [4, 1] + - StoreVectorWidth: [1] + # - 
NumElementsPerBatchStore: [0, 2, 4, 8] + # - NumElementsPerBatchStore: [8] + - AssertAlphaValue: [1] + - GlobalReadVectorWidth: [8] + - AssertSizeEqual: [{2: 1}] + + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [ [4096], [4096], [1], [1024] ] + - Range: [ [4103], [4096], [1], [1024] ] + - Range: [ [4096], [4103], [1], [1024] ] + - Range: [ [4103], [4103], [1], [1024] ] + - Range: [ [4096], [4096], [1], [1031] ] + - Range: [ [4103], [4103], [1], [1031] ] diff --git a/Tensile/Tests/extended/stream_k/sk_sgemm.yaml b/Tensile/Tests/extended/stream_k/sk_sgemm.yaml new file mode 100644 index 000000000..b6abd6157 --- /dev/null +++ b/Tensile/Tests/extended/stream_k/sk_sgemm.yaml @@ -0,0 +1,119 @@ +TestParameters: + marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102] # not supported by arch + +GlobalParameters: + NumElementsToValidate: -1 + BoundsCheck: True + KernelTime: True + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + # DataInitTypeA: 1 + # DataInitTypeB: 1 + # DataInitTypeC: 1 + # ValidationPrintValids: True + MaxWorkspaceSize: 134217728 + # PrintSolutionRejectionReason: True + # ForceGenerateKernel: True + # GenerateSourcesAndExit: True + # NumWarmups: 2 + # EnqueuesPerSync: 10 + # NumBenchmarks: 10 + # BufferOffsetA: 500 + # BufferOffsetB: 700 + # BufferOffsetC: 900 + # BufferOffsetD: 1100 + +BenchmarkProblems: + - # sgemm NT + - # ProblemType + OperationType: GEMM + DataType: s + TransposeA: False + TransposeB: True + UseBeta: False + Batched: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - PrefetchLocalRead: [True] + ForkParameters: + - MatrixInstruction: + - [32, 32, 2, 1, 1, 2,2, 2,2] + # - [32, 32, 2, 1, 1, 3,3, 2,2] + - [32, 32, 2, 1, 1, 4,4, 2,2] + # - [32, 32, 2, 1, 1, 2,1, 2,2] + # - [32, 32, 2, 1, 1, 1,1, 2,2] + # - [16, 16, 4, 1, 1, 3,3, 2,2] + # - [16, 16, 4, 1, 1, 4,1, 2,2] + # - [16, 16, 4, 1, 1, 4,2, 2,2] + - [16, 16, 4, 1, 1, 4,4, 2,2] + - [16, 16, 4, 1, 1, 8,8, 2,2] + # - [16, 16, 4, 1, 1, 2,2, 2,2] + # - [16, 16, 4, 1, 1, 2,1, 2,2] + # - [16, 16, 4, 1, 1, 1,1, 2,2] + - ThreadTile: + - [ 1, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + # - WorkGroupMapping: [0, 1, 2, 4, 8, 16, 32, 64] # works + - WorkGroupMapping: [0, 8] + - GlobalSplitU: [1] + - DepthU: [ 8 ] + # - DepthU: [ 8, 16, 32 ] + # - DepthU: [ 8, 12, 16, 32 ] + # - DepthU: [ 2, 4, 8, 16, 32, 64 ] + # - DepthU: [ 8, 9, 10, 11, 12, 13, 14, 15, 16 ] # depthu 14 failed a test + - VectorWidth: [1] + - StreamK: [1, 2] + - StaggerU: [0, 32] + # - StaggerU: [0] + - ScheduleIterAlg: [3] + # - SourceSwap: [False, True] + - SourceSwap: [False] + # - ExpandPointerSwap: [False, True] + # - ExpandPointerSwap: [False] + - PrefetchLocalRead: [3, 5] + # - PrefetchLocalRead: [1, 3, 5, 9, 13, 17] + # - PrefetchLocalRead: [1, 9, 10, 11, 12, 13, 14, 15, 16, 17] + # - PrefetchGlobalRead: [1, 2] + - PrefetchGlobalRead: [1] + # - 1LDSBuffer: [0, 1] + # - 1LDSBuffer: [1] + # - EdgeType: ["Branch", "ShiftPtr"] + # - EdgeType: ["ShiftPtr"] + # - MIArchVgpr: [0, 1] + # - MIArchVgpr: [1] + # - StoreVectorWidth: [4, 1] + - StoreVectorWidth: [4] + # - NumElementsPerBatchStore: [0, 2, 4, 8] + # - NumElementsPerBatchStore: [8] + # - AssertAlphaValue: [1] + - AssertSizeEqual: [{2: 1}] + + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + 
BenchmarkFinalParameters: + - ProblemSizes: + # - Exact: [10912, 10976, 1, 13856] + # - Exact: [16000, 13632, 1, 9040] + # - Range: [ [4096, 7, 6144], [4096, 7, 6144], [1], [1024] ] + # - Range: [ [4096, 31, 5120], [4096, 31, 5120], [1], [1024] ] + # - Range: [ [512], [512], [1], [512] ] + - Range: [ [4096], [4096], [1], [1024] ] + - Range: [ [4103], [4096], [1], [1024] ] + - Range: [ [4096], [4103], [1], [1024] ] + - Range: [ [4103], [4103], [1], [1024] ] + - Range: [ [4096], [4096], [1], [1031] ] + - Range: [ [4103], [4103], [1], [1031] ] + # - Range: [ [4096], [4096], [1], [1024, 1, 1088] ] + # - Range: [ [4096, 31, 5120], [4096, 31, 5120], [1], [1024, 7, 1280] ] + # - Range: [ [4096, 31, 5120], [4096], [1], [1024] ] + # - Range: [ [4096], [4096, 31, 5120], [1], [1024] ] + # - Range: [ [4096], [4096], [1], [1024, 7, 1280] ] + # - Range: [ [1031], [1031], [1], [1031] ] + # - Range: [ [1031], [1031], [8], [1031] ] + # - Range: [ [4096], [4096], [2], [1024] ] diff --git a/Tensile/Tests/pre_checkin/direct_to_vgpr/dtv_sgemm_lite.yaml b/Tensile/Tests/pre_checkin/direct_to_vgpr/dtv_sgemm_lite.yaml index 13bcfab61..942e8d8ff 100644 --- a/Tensile/Tests/pre_checkin/direct_to_vgpr/dtv_sgemm_lite.yaml +++ b/Tensile/Tests/pre_checkin/direct_to_vgpr/dtv_sgemm_lite.yaml @@ -42,7 +42,6 @@ BenchmarkProblems: # - [ 8, 8, 1 ] - DepthU: [16,32] - ExpandPointerSwap: [1] # 1 for DirectToVgpr - - LdsBlockSizePerPad: [-1] - LdsPadB: [4] - PrefetchGlobalRead: [1,2] - GlobalReadPerMfma: [1] diff --git a/Tensile/Tests/pre_checkin/mfma/1LDSB.yaml b/Tensile/Tests/pre_checkin/mfma/1LDSB.yaml index ae5a6ee46..27d2571c5 100644 --- a/Tensile/Tests/pre_checkin/mfma/1LDSB.yaml +++ b/Tensile/Tests/pre_checkin/mfma/1LDSB.yaml @@ -46,7 +46,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2,4] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [-1] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] - LdsPadA: [-1] - LdsPadB: [-1] - 1LDSBuffer: [1] @@ -98,7 +99,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2,4] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [-1] + - LdsBlockSizePerPadA: [-1,0] + - LdsBlockSizePerPadB: [-1] - LdsPadA: [-1] - LdsPadB: [-1] - 1LDSBuffer: [1] @@ -145,7 +147,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [-1] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1,0] - LdsPadA: [-1] - LdsPadB: [-1] - 1LDSBuffer: [1] diff --git a/Tensile/Tests/pre_checkin/mfma/dgemm_gb_global_ldd.yaml b/Tensile/Tests/pre_checkin/mfma/dgemm_gb_global_ldd.yaml index 2403ade3f..f0884415c 100644 --- a/Tensile/Tests/pre_checkin/mfma/dgemm_gb_global_ldd.yaml +++ b/Tensile/Tests/pre_checkin/mfma/dgemm_gb_global_ldd.yaml @@ -64,7 +64,8 @@ BenchmarkProblems: - TransposeLDS: [0] - LdsPadA: [-1] - LdsPadB: [-1] - - LdsBlockSizePerPad: [-1] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] - VectorWidth: [2] - WorkGroupMapping: [8] # - AssertSummationElementMultiple: [1, 2] diff --git a/Tensile/Tests/pre_checkin/mfma/wider_local_read.yaml b/Tensile/Tests/pre_checkin/mfma/wider_local_read.yaml index a85f57899..b4dca85f8 100644 --- a/Tensile/Tests/pre_checkin/mfma/wider_local_read.yaml +++ b/Tensile/Tests/pre_checkin/mfma/wider_local_read.yaml @@ -46,7 +46,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2,4] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: @@ -97,7 +98,8 @@ 
BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2,4] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: @@ -143,7 +145,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: @@ -322,7 +325,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: @@ -373,7 +377,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [0] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: @@ -419,7 +424,8 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1,2] - TransposeLDS: [0,1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [128] - LdsPadA: [-1] - LdsPadB: [-1] BenchmarkJoinParameters: diff --git a/Tensile/Tests/pre_checkin/wmma/hgemm_wmma.yaml b/Tensile/Tests/pre_checkin/wmma/hgemm_wmma.yaml index 62af3237d..af4668477 100644 --- a/Tensile/Tests/pre_checkin/wmma/hgemm_wmma.yaml +++ b/Tensile/Tests/pre_checkin/wmma/hgemm_wmma.yaml @@ -51,7 +51,8 @@ BenchmarkProblems: - ScheduleLocalWrite: [1] - ScheduleGlobalRead: [1] - ScheduleIterAlg: [1,3] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [0] - LdsPadB: [8] BenchmarkJoinParameters: diff --git a/Tensile/Tests/pre_checkin/wmma/hpa_bfloat16_gemm_wmma.yaml b/Tensile/Tests/pre_checkin/wmma/hpa_bfloat16_gemm_wmma.yaml index f53a6dd04..071d1e997 100644 --- a/Tensile/Tests/pre_checkin/wmma/hpa_bfloat16_gemm_wmma.yaml +++ b/Tensile/Tests/pre_checkin/wmma/hpa_bfloat16_gemm_wmma.yaml @@ -54,7 +54,8 @@ BenchmarkProblems: - ScheduleLocalWrite: [1] - ScheduleGlobalRead: [1] - ScheduleIterAlg: [1,3] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [0] - LdsPadB: [8] - StaggerUStride: [128,256] diff --git a/Tensile/Tests/pre_checkin/wmma/hpa_hgemm_wmma.yaml b/Tensile/Tests/pre_checkin/wmma/hpa_hgemm_wmma.yaml index 9ff71c0c4..9f73d2685 100644 --- a/Tensile/Tests/pre_checkin/wmma/hpa_hgemm_wmma.yaml +++ b/Tensile/Tests/pre_checkin/wmma/hpa_hgemm_wmma.yaml @@ -52,7 +52,8 @@ BenchmarkProblems: - ScheduleLocalWrite: [1] - ScheduleGlobalRead: [1] - ScheduleIterAlg: [1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [0] - LdsPadB: [8] - StaggerUStride: [128,256] diff --git a/Tensile/Tests/pre_checkin/wmma/hpa_igemm_wmma.yaml b/Tensile/Tests/pre_checkin/wmma/hpa_igemm_wmma.yaml index b8ff52400..44dd5c225 100644 --- a/Tensile/Tests/pre_checkin/wmma/hpa_igemm_wmma.yaml +++ b/Tensile/Tests/pre_checkin/wmma/hpa_igemm_wmma.yaml @@ -54,7 +54,8 @@ BenchmarkProblems: - ScheduleLocalWrite: [1] - ScheduleGlobalRead: [1] - ScheduleIterAlg: [1] - - LdsBlockSizePerPad: [128] + - LdsBlockSizePerPadA: [128] + - LdsBlockSizePerPadB: [128] - LdsPadA: [0] - LdsPadB: [8] BenchmarkJoinParameters: diff --git a/Tensile/Tests/unit/test_HardwarePredicates.py b/Tensile/Tests/unit/test_HardwarePredicates.py index dbf41799f..23519e695 100644 --- 
a/Tensile/Tests/unit/test_HardwarePredicates.py +++ b/Tensile/Tests/unit/test_HardwarePredicates.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2020-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2023 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,9 @@ def test_hardware_predicate_comparison(): c = HardwarePredicate("TruePred") d = HardwarePredicate.FromHardware((9,0,8), 60) e = HardwarePredicate.FromHardware((9,0,8), 64) + f = HardwarePredicate.FromHardware((9,4,2)) + g = HardwarePredicate.FromHardware((9,4,2), isAPU=0) + h = HardwarePredicate.FromHardware((9,4,2), isAPU=1) assert a < b assert a < c @@ -49,6 +52,21 @@ def test_hardware_predicate_comparison(): assert e < c assert e < d + assert g < a + assert g < b + assert g < c + assert g < d + assert g < e + assert g < f + + assert h < a + assert h < b + assert h < c + assert h < d + assert h < e + assert h < f + assert h < g + assert not b < a assert not c < a assert not c < b @@ -56,12 +74,17 @@ def test_hardware_predicate_comparison(): assert not b < e assert not c < e assert not d < e + assert not f < g + assert not g < h assert not a < a assert not b < b assert not c < c assert not d < d assert not e < e + assert not f < f + assert not g < g + assert not h < h def hardware_library_objects_order(): objs = [PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromISA((9,0,0))}]), @@ -87,6 +110,29 @@ def test_hardware_library_merge_order(libraries): for r in lib.rows[:-1]: assert r['predicate'] != HardwarePredicate('TruePred') +def hardware_library_objects_order2(): + objs = [PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromISA((9,0,6))}]), + PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromHardware((9,4,2))}]), + PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromHardware((9,4,2), isAPU=0)}]), + PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromHardware((9,4,2), isAPU=1)}]), + PredicateLibrary('Hardware', [{'predicate': HardwarePredicate('TruePred')}]) + ] + + return [copy.deepcopy(libs) for libs in itertools.permutations(objs)] + +@pytest.mark.parametrize("libraries", hardware_library_objects_order2()) +def test_hardware_library_merge_order2(libraries): + lib = libraries[0] + for lib2 in libraries[1:]: + lib.merge(lib2) + + assert lib.rows[-1]['predicate'] == HardwarePredicate('TruePred') + assert lib.rows[0]['predicate'] == HardwarePredicate.FromHardware((9,4,2), isAPU=1) + assert lib.rows[1]['predicate'] == HardwarePredicate.FromHardware((9,4,2), isAPU=0) + assert lib.rows[3]['predicate'] == HardwarePredicate.FromHardware((9,4,2)) + for r in lib.rows[:-1]: + assert r['predicate'] != HardwarePredicate('TruePred') + def hardware_library_objects_dups(): objs = [PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromISA((9,0,0)), 'library': PredicateLibrary()}]), PredicateLibrary('Hardware', [{'predicate': HardwarePredicate.FromISA((9,0,6)), 'library': PredicateLibrary()}]), diff --git a/Tensile/__init__.py b/Tensile/__init__.py index 2c19a3282..bcbaed7d9 100644 --- a/Tensile/__init__.py +++ b/Tensile/__init__.py @@ -26,7 +26,7 @@ from __future__ import print_function # hardcoded tensile version; also in Tensile/Source/TensileConfigVersion.cmake -__version__ = "4.39.0" 
+__version__ = "4.40.0" def PrintTensileRoot(): import os.path diff --git a/Tensile/cmake/TensileConfigVersion.cmake b/Tensile/cmake/TensileConfigVersion.cmake index ecce0b65e..8c97bc461 100644 --- a/Tensile/cmake/TensileConfigVersion.cmake +++ b/Tensile/cmake/TensileConfigVersion.cmake @@ -24,7 +24,7 @@ # hardcoded tensile version; also in Tensile/__init__.py set(TENSILE_VERSION_MAJOR 4) -set(TENSILE_VERSION_MINOR 39) +set(TENSILE_VERSION_MINOR 40) set(TENSILE_VERSION_PATCH 0) # export version diff --git a/bump-version.sh b/bump-version.sh index 993a233ab..f1c7ff190 100755 --- a/bump-version.sh +++ b/bump-version.sh @@ -27,8 +27,8 @@ # This script needs to be edited to bump version for new release. # Version will be bumped in Tensile/__init__.py and in .yaml files -OLD_VERSION="4.38.0" -NEW_VERSION="4.39.0" +OLD_VERSION="4.39.0" +NEW_VERSION="4.40.0" OLD_MINIMUM_REQUIRED_VERSION="MinimumRequiredVersion: 4.7.2" NEW_MINIMUM_REQUIRED_VERSION="MinimumRequiredVersion: 4.8.0" diff --git a/pytest.ini b/pytest.ini index 206a3dd03..2dc9a3297 100644 --- a/pytest.ini +++ b/pytest.ini @@ -79,6 +79,7 @@ markers = pbd source stagger_u + stream_k syntax_error tensor_contraction vector_width diff --git a/tuning/automation/rocblas-benchInputCreator.py b/tuning/automation/rocblas-benchInputCreator.py index 8abb7ae59..58dbc1dd6 100644 --- a/tuning/automation/rocblas-benchInputCreator.py +++ b/tuning/automation/rocblas-benchInputCreator.py @@ -25,6 +25,8 @@ # Generates rocblas-bench input files from the library logic files. # creates the benchmark and verification files: # $ python3 rocblas-benchInputCreator.py -v ../libLogics/aldebaran_Cijk_Ailk_Bjlk_BBS_BH.yaml ./ BSS_NT +# creates the benchmark and verification files with hpl initialization: +# $ python3 rocblas-benchInputCreator.py -v -i hpl ../libLogics/aldebaran_Cijk_Ailk_Bjlk_BBS_BH.yaml ./ BSS_NT # creates the benchmark file: # $ python3 rocblas-benchInputCreator.py ../libLogics/aldebaran_Cijk_Ailk_Bjlk_BBS_BH.yaml ./ BSS_NT @@ -32,23 +34,23 @@ import os import yaml - typeIndexToName = {0: "f32_r", 1: "f64_r", 2: "f32_c", 3: "f64_c", 4: "f16_r", 5: "i8_r", 6: "i32_r", 7: "bf16_r", 8: "i8_r", 10: "f8_r", 11: "bf8_r", 12: "f8b8", 13: "b8f8"} - def parseArgs(): argParser = argparse.ArgumentParser() h = {"libLogic" : "Input library logic file", "outDir" : "Output directory for rocBLAS-bench yaml files", "verify" : "Also output verify version of yaml files", - "outfile" : "the name of output file" + "outfile" : "the name of output file", + "initial" : "Matrix initialization: hpl, trig, int. The default is trig for non Int8 datatype, and int for Int8." 
} argParser.add_argument("libLogic", metavar="logic-file", type=str, help=h["libLogic"]) argParser.add_argument("outDir", metavar="output-dir", type=str, help=h["outDir"]) argParser.add_argument("outfile", metavar="output-file", type=str, help=h["outfile"]) argParser.add_argument("--verify", "-v", action="store_true", help=h["verify"]) + argParser.add_argument("--initialization", "-i", action="store", type=str, default = 'trig', help=h["initial"]) return argParser.parse_args() @@ -163,8 +165,9 @@ def getSizeParams(size, transA, transB): def createYaml(args, problem, sizeMappings, verify): bench = [] benchStrided = [] + benchGeneralBatched = [] - # get GEMM fucnt and matrix orientation - Fixed for each library + # get GEMM function and matrix orientation - Fixed for each library problemParams = getProblemType(problem) transA = problem["TransposeA"] transB = problem["TransposeB"] @@ -177,25 +180,48 @@ def createYaml(args, problem, sizeMappings, verify): else: otherParams = {"alpha": 1, "beta": 1, "iters": 10, "cold_iters": 2} + #initialization + if (args.initialization=='hpl' and problemParams["a_type"]!="i8_r"): + init = {"initialization": "hpl"} + elif (args.initialization=='trig' and problemParams["a_type"]!="i8_r"): + init = {"initialization": "trig_float"} + elif args.initialization== 'int': + init = {"initialization": "rand_int"} + else: + print(f"Initialization {args.initialization} is not allowed for int8 datatype. Initialization changed to rand_int.") + init = {"initialization": "rand_int"} + + # check if the library is General Batched based on the library name + generalBatched = False + if "_GB.yaml" in os.path.split(args.libLogic)[-1]: + generalBatched = True + # create rocBLAS-bench call for each size in logic file for (size, _) in sizeMappings: # size[0] = M, size[1] = N, size[2] = batch_count, size[3] = K, size[4] = ldc, size[5] = ldd, size[6] = lda, size[7] = ldb params = {} - if (size[2] == 1 and not f8gemm): # non-f8, non-batched gemm (serves both HPA and non-HPA) + if (not generalBatched and size[2] == 1 and not f8gemm): # non-f8, non-batched gemm (serves both HPA and non-HPA) params["rocblas_function"] = "rocblas_gemm_ex" - elif (size[2] != 1 and not f8gemm): # non-f8, strided_batched gemm (serves both HPA and non-HPA) + elif (not generalBatched and size[2] != 1 and not f8gemm): # non-f8, strided_batched gemm (serves both HPA and non-HPA) params["rocblas_function"] = "rocblas_gemm_strided_batched_ex" - else: # f8 + elif not generalBatched: # f8 params["rocblas_function"] = "rocblas_gemm_ex3" + elif (generalBatched and not f8gemm): # non-f8, general batched gemm (serves both HPA and non-HPA) currently there is no f8 general batched + params["rocblas_function"] = "rocblas_gemm_batched_ex" + else: + raise RuntimeError(" F8 GEMM is not supporting General Batched.") sizeParams = getSizeParams(size, transA, transB) params.update(problemParams) params.update(sizeParams) params.update(otherParams) + params.update(init) - if size[2] == 1: + if (size[2] == 1 and not generalBatched): bench.append(params) + elif (generalBatched): + benchGeneralBatched.append(params) else: benchStrided.append(params) @@ -204,8 +230,9 @@ def createYaml(args, problem, sizeMappings, verify): prefix += "_verify" if verify else "" benchPath = os.path.join(args.outDir, prefix + "_bench.yaml") - benchStridedPath = os.path.join(args.outDir, prefix +"bench-strided.yaml") - + benchStridedPath = os.path.join(args.outDir, prefix +"_bench-strided.yaml") + benchGeneralBatchedPath = os.path.join(args.outDir, prefix 
+"_bench-general-batched.yaml") + # write output if len(bench) > 0: with open(benchPath, "w") as f: @@ -213,10 +240,16 @@ def createYaml(args, problem, sizeMappings, verify): if len(benchStrided) > 0: with open(benchStridedPath, "w") as f: yaml.safe_dump(benchStrided, f, default_flow_style=None, sort_keys=False, width=5000) + if len(benchGeneralBatched) > 0: + with open(benchGeneralBatchedPath, "w") as f: + yaml.safe_dump(benchGeneralBatched, f, default_flow_style=None, sort_keys=False, width=5000) def main(): args = parseArgs() + if not (args.initialization in ['hpl', 'trig', 'int']): + raise RuntimeError(f"Initialization {args.initialization} is not allowed. Choose from hpl, trig, or int.") + with open(args.libLogic) as f: logicData = yaml.safe_load(f) diff --git a/tuning_docs/tensile_tuning.tex b/tuning_docs/tensile_tuning.tex index 96772c7e1..7ae619170 100644 --- a/tuning_docs/tensile_tuning.tex +++ b/tuning_docs/tensile_tuning.tex @@ -1013,7 +1013,7 @@ \subsection{Automation Workflow} \begin{tikzpicture} [node distance=3cm, auto,>=latex', thick] \path[->] node[format,label={[align=center]:rocBLAS\\log}] (tune) {test.log}; - \path[->] node[format, align=center,right of=tune,label={[align=center]:creaet\\tuning\\artifacts}] (ptune) {tuning\\artifacts} + \path[->] node[format, align=center,right of=tune,label={[align=center]:create\\tuning\\artifacts}] (ptune) {tuning\\artifacts} (tune) edge node[align=center] {provision\\tuning} (ptune); \path[->] node[format, right of=ptune, label={[align=center]:tuning\\results}] (tune) {logic.yaml} (ptune) edge node {tune} (tune); From f63ff56a33ab1ac3ea352c102e8fa7dc424a02dc Mon Sep 17 00:00:00 2001 From: Babak Date: Fri, 17 Nov 2023 11:38:46 -0500 Subject: [PATCH 2/2] Update CHANGELOG.md for ROCm 6.1.0 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e34e41583..bd7c1c9ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log for Tensile -## (Unreleased) Tensile 4.40.0 +## Tensile 4.40.0 for ROCm 6.1.0 ### Additions - new DisableKernelPieces values to invalidate local read, local write, and global read - stream-K kernel generation, including two-tile stream-k algorithm by setting StreamK=3