diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index 41487eb365..0000000000 --- a/.gitattributes +++ /dev/null @@ -1,5 +0,0 @@ -# Enable shader syntax highlighting -*.rahit linguist-language=GLSL -*.rcall linguist-language=GLSL -*.rgen linguist-language=GLSL -*.rint linguist-language=GLSL diff --git a/CMakeLists.txt b/CMakeLists.txt index ba7f747f04..79dfc302b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -128,6 +128,9 @@ endif() ### VKGC build LLPC ################################################################ if(ICD_BUILD_LLPC) + include("cmake/compilerutils.cmake") + add_compilerutils_projects() + target_include_directories(vkgc INTERFACE ${PROJECT_SOURCE_DIR}/llpc/include diff --git a/cmake/compilerutils.cmake b/cmake/compilerutils.cmake new file mode 100644 index 0000000000..440e63d741 --- /dev/null +++ b/cmake/compilerutils.cmake @@ -0,0 +1,40 @@ +## + ####################################################################################################################### + # + # Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. + # + # Permission is hereby granted, free of charge, to any person obtaining a copy + # of this software and associated documentation files (the "Software"), to deal + # in the Software without restriction, including without limitation the rights + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + # copies of the Software, and to permit persons to whom the Software is + # furnished to do so, subject to the following conditions: + # + # The above copyright notice and this permission notice shall be included in all + # copies or substantial portions of the Software. + # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + # SOFTWARE. + # + ####################################################################################################################### + +set(LLPC_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/..") + +# Function to add compilerutils as LLVM external projects. +# This appends the project names to LLVM_EXTERNAL_PROJECTS and sets each LLVM_EXTERNAL_*_SOURCE_DIR, +# all in the caller's scope. +macro(add_compilerutils_projects) + if (NOT compilerutils IN_LIST LLVM_EXTERNAL_PROJECTS) + if (NOT llvm_dialects IN_LIST LLVM_EXTERNAL_PROJECTS) + list(APPEND LLVM_EXTERNAL_PROJECTS llvm_dialects) + set(LLVM_EXTERNAL_LLVM_DIALECTS_SOURCE_DIR "${LLPC_SOURCE_DIR}/imported/llvm-dialects") + endif() + list(APPEND LLVM_EXTERNAL_PROJECTS CompilerUtils) + set(LLVM_EXTERNAL_COMPILERUTILS_SOURCE_DIR "${LLPC_SOURCE_DIR}/compilerutils") + endif() +endmacro() diff --git a/cmake/continuations.cmake b/cmake/continuations.cmake index 03e00d6135..8500b5fb17 100644 --- a/cmake/continuations.cmake +++ b/cmake/continuations.cmake @@ -25,10 +25,13 @@ set(LLPC_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/..") +include("${LLPC_SOURCE_DIR}/cmake/compilerutils.cmake") + # Macro to add continuations and its dependencies as LLVM external projects. # This appends the project names to LLVM_EXTERNAL_PROJECTS and sets each LLVM_EXTERNAL_*_SOURCE_DIR, # all in the caller's scope. macro(add_continuations_projects) + add_compilerutils_projects() if (NOT continuations IN_LIST LLVM_EXTERNAL_PROJECTS) if (NOT llvm_dialects IN_LIST LLVM_EXTERNAL_PROJECTS) list(APPEND LLVM_EXTERNAL_PROJECTS llvm_dialects) diff --git a/compilerutils/CMakeLists.txt b/compilerutils/CMakeLists.txt new file mode 100644 index 0000000000..1d97dd27c3 --- /dev/null +++ b/compilerutils/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.13.4) + +project(CompilerUtils LANGUAGES CXX) + +function(set_compiler_options PROJECT_NAME) + # Output with color if in terminal: https://github.com/ninja-build/ninja/wiki/FAQ + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + target_compile_options("${PROJECT_NAME}" PRIVATE -fdiagnostics-color=always) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + target_compile_options("${PROJECT_NAME}" PRIVATE -fcolor-diagnostics) + endif() +endfunction() + +add_llvm_library(LLVMCompilerUtils + lib/CompilerUtils.cpp + lib/TypeLowering.cpp + + DEPENDS + intrinsics_gen + + LINK_COMPONENTS + Analysis + Core + Support +) + +target_include_directories(LLVMCompilerUtils PUBLIC + $ + $ + $ +) + +target_link_libraries(LLVMCompilerUtils PUBLIC llvm_dialects) +set_compiler_options(LLVMCompilerUtils) + +target_compile_features(LLVMCompilerUtils PUBLIC cxx_std_17) +set_target_properties(LLVMCompilerUtils PROPERTIES CXX_EXTENSIONS OFF) diff --git a/compilerutils/include/compilerutils/CompilerUtils.h b/compilerutils/include/compilerutils/CompilerUtils.h new file mode 100644 index 0000000000..458324de52 --- /dev/null +++ b/compilerutils/include/compilerutils/CompilerUtils.h @@ -0,0 +1,71 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ + +//===- CompilerUtils.h - Library for compiler frontends -------------------===// +// +// Implements several shared helper functions. +// +//===----------------------------------------------------------------------===// + +#ifndef COMPILERUTILS_H +#define COMPILERUTILS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/IRBuilder.h" + +namespace llvm { + +class CallInst; +class Function; +class Type; +class Value; + +} // namespace llvm + +namespace CompilerUtils { + +// Create an LLVM function call to the named function. The callee is built +// automatically based on return type and its parameters. +// +// @param funcName : Name of the callee +// @param retTy : Return type of the callee +// @param args : Arguments to pass to the callee +// @param attribs : Function attributes +// @param instName : Name to give instruction +llvm::CallInst *createNamedCall(llvm::IRBuilder<> &, llvm::StringRef, llvm::Type *, llvm::ArrayRef, + llvm::ArrayRef, const llvm::Twine & = ""); + +// Modify the function argument types, and return the new function. NOTE: the +// function does not do any uses replacement, so the caller should call +// replaceAllUsesWith() for the function and arguments afterwards. +llvm::Function *mutateFunctionArguments(llvm::Function &, llvm::Type *, const llvm::ArrayRef, + llvm::AttributeList); + +} // namespace CompilerUtils + +#endif diff --git a/lgc/include/lgc/util/TypeLowering.h b/compilerutils/include/compilerutils/TypeLowering.h similarity index 73% rename from lgc/include/lgc/util/TypeLowering.h rename to compilerutils/include/compilerutils/TypeLowering.h index 3eabb86d94..5bec7d1e96 100644 --- a/lgc/include/lgc/util/TypeLowering.h +++ b/compilerutils/include/compilerutils/TypeLowering.h @@ -60,43 +60,49 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" -namespace lgc { - class TypeLowering; /// Given a type, check if it should be replaced. /// -/// Return an empty vector if this function doesn't know how to handle the given type. Subsequent conversion rules will -/// then be considered. +/// Return an empty vector if this function doesn't know how to handle the given +/// type. Subsequent conversion rules will then be considered. /// -/// Otherwise, return a vector with the replacement type(s). If the type is known to remain unchanged, return a -/// singleton vector containing just the original type. +/// Otherwise, return a vector with the replacement type(s). If the type is +/// known to remain unchanged, return a singleton vector containing just the +/// original type. using TypeLoweringFn = llvm::SmallVector(TypeLowering &, llvm::Type *); -/// Given a constant that is known to be meant to be replaced based on its type, attempt to replace it. +/// Given a constant that is known to be meant to be replaced based on its type, +/// attempt to replace it. /// /// Return a non-empty vector if this function was able to handle the constant. /// -/// Otherwise, return an empty vector, and subsequent rules will be applied. Default rules exist for poison, undef, -/// and "null-like" (zeroinitializer etc.). +/// Otherwise, return an empty vector, and subsequent rules will be applied. +/// Default rules exist for poison, undef, and "null-like" (zeroinitializer +/// etc.). using ConstantTypeLoweringFn = llvm::SmallVector(TypeLowering &, llvm::Constant *, llvm::ArrayRef); // ===================================================================================================================== -/// Helper for lowerings that need to replace values of one type by one or more values of another type. +/// Helper for lowerings that need to replace values of one type by one or more +/// values of another type. /// /// This helper really has two parts: /// -/// - A type-level part that applies @ref TypeLoweringFn rules and caches the result -/// - A value-level part that maintains a mapping of replaced values and provides generic handlers for core +/// - A type-level part that applies @ref TypeLoweringFn rules and caches the +/// result +/// - A value-level part that maintains a mapping of replaced values and +/// provides generic handlers for core /// instructions like phi, select, and alloca /// -/// The type-level part can be reused even as the value-level part is cleared by @ref finishCleanup, assuming that the -/// type replacements are consistent (which they might not always be, e.g. where the replacement depends on the target -/// architecture). +/// The type-level part can be reused even as the value-level part is cleared by +/// @ref finishCleanup, assuming that the type replacements are consistent +/// (which they might not always be, e.g. where the replacement depends on the +/// target architecture). /// -/// The value-level part is meant to be used as a nested @ref llvm_dialects::Visitor client. It requires RPO traversal -/// order. Its intended use is along the following lines: +/// The value-level part is meant to be used as a nested @ref +/// llvm_dialects::Visitor client. It requires RPO traversal order. Its intended +/// use is along the following lines: /// @code /// struct MyPayload { /// TypeLowering lowering; @@ -108,8 +114,9 @@ using ConstantTypeLoweringFn = llvm::SmallVector(TypeLowering /// /// MyPayload payload; /// -/// // Reverse post order traversal through functions, replacing instructions with converted types as we go. -/// static const auto visitor = VisitorBuilder +/// // Reverse post order traversal through functions, replacing instructions +/// with converted types as we go. static const auto visitor = +/// VisitorBuilder /// .add(...) /// .nest(&TypeLowering::registerVisitors) /// .build(); @@ -118,42 +125,42 @@ using ConstantTypeLoweringFn = llvm::SmallVector(TypeLowering /// // Fixup phi nodes. /// payload.lowering.finishPhis(); /// -/// // Erase all instructions that "have been replaced" (by calling replaceInstruction for them). -/// payload.lowering.finishCleanup(); +/// // Erase all instructions that "have been replaced" (by calling +/// replaceInstruction for them). payload.lowering.finishCleanup(); /// @endcode class TypeLowering { public: - TypeLowering(llvm::LLVMContext &context); + TypeLowering(llvm::LLVMContext &); llvm::LLVMContext &getContext() const { return m_builder.getContext(); } - void addRule(std::function rule); - void addConstantRule(std::function rule); + void addRule(std::function); + void addConstantRule(std::function); - llvm::ArrayRef convertType(llvm::Type *type); + llvm::ArrayRef convertType(llvm::Type *); - static void registerVisitors(llvm_dialects::VisitorBuilder &builder); + static void registerVisitors(llvm_dialects::VisitorBuilder &); - llvm::SmallVector getValue(llvm::Value *value); - llvm::SmallVector getValueOptional(llvm::Value *value); - void replaceInstruction(llvm::Instruction *inst, llvm::ArrayRef mapping); - void eraseInstruction(llvm::Instruction *inst); + llvm::SmallVector getValue(llvm::Value *); + llvm::SmallVector getValueOptional(llvm::Value *); + void replaceInstruction(llvm::Instruction *, llvm::ArrayRef); + void eraseInstruction(llvm::Instruction *); - llvm::Function *lowerFunctionArguments(llvm::Function &fn); + llvm::Function *lowerFunctionArguments(llvm::Function &); void finishPhis(); bool finishCleanup(); private: - void recordValue(llvm::Value *value, llvm::ArrayRef mapping); - void replaceMappingWith(llvm::Value *toReplace, llvm::Value *with); + void recordValue(llvm::Value *, llvm::ArrayRef); + void replaceMappingWith(llvm::Value *, llvm::Value *); - void visitAlloca(llvm::AllocaInst &alloca); - void visitExtract(llvm::ExtractValueInst &extract); - void visitInsert(llvm::InsertValueInst &insert); - void visitLoad(llvm::LoadInst &load); - void visitPhi(llvm::PHINode &phi); - void visitSelect(llvm::SelectInst &select); - void visitStore(llvm::StoreInst &store); + void visitAlloca(llvm::AllocaInst &); + void visitExtract(llvm::ExtractValueInst &); + void visitInsert(llvm::InsertValueInst &); + void visitLoad(llvm::LoadInst &); + void visitPhi(llvm::PHINode &); + void visitSelect(llvm::SelectInst &); + void visitStore(llvm::StoreInst &); /// Type conversion rules. llvm::SmallVector> m_rules; @@ -170,18 +177,18 @@ class TypeLowering { /// Map original values to type-converted values. /// /// For 1-1 mappings, this stores a value pointer. - /// For 1-N mappings, this stores ((index << 1) | 1), where index is the index into m_convertedValueList at which the - /// converted values can be found. + /// For 1-N mappings, this stores ((index << 1) | 1), where index is the index + /// into m_convertedValueList at which the converted values can be found. llvm::DenseMap m_valueMap; std::vector m_convertedValueList; - /// Reverse map of values that occur as type-converted values to where they occur. The vector elements are either a - /// value pointer (for 1-1 mapped values) or ((index << 1) | 1), where index is the index into m_convertedValueList. + /// Reverse map of values that occur as type-converted values to where they + /// occur. The vector elements are either a value pointer (for 1-1 mapped + /// values) or ((index << 1) | 1), where index is the index into + /// m_convertedValueList. llvm::DenseMap> m_valueReverseMap; std::vector>> m_phis; std::vector m_instructionsToErase; - llvm::SmallVector m_functionToErase; + llvm::SmallVector m_functionsToErase; }; - -} // namespace lgc diff --git a/compilerutils/lib/CompilerUtils.cpp b/compilerutils/lib/CompilerUtils.cpp new file mode 100644 index 0000000000..8d3f0ab589 --- /dev/null +++ b/compilerutils/lib/CompilerUtils.cpp @@ -0,0 +1,100 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2020-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ + +#include "compilerutils/CompilerUtils.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" + +using namespace llvm; + +// ===================================================================================================================== +// Create an LLVM function call to the named function. The callee is built +// automatically based on return type and its parameters. +// +// @param funcName : Name of the callee +// @param retTy : Return type of the callee +// @param args : Arguments to pass to the callee +// @param attribs : Function attributes +// @param instName : Name to give instruction +CallInst *CompilerUtils::createNamedCall(IRBuilder<> &builder, StringRef funcName, Type *retTy, ArrayRef args, + ArrayRef attribs, const Twine &instName) { + assert(!funcName.empty()); + Module *mod = builder.GetInsertBlock()->getParent()->getParent(); + Function *func = dyn_cast_or_null(mod->getFunction(funcName)); + if (!func) { + SmallVector argTys; + argTys.reserve(args.size()); + for (auto *arg : args) + argTys.push_back(arg->getType()); + + auto *funcTy = FunctionType::get(retTy, argTys, false); + func = Function::Create(funcTy, GlobalValue::ExternalLinkage, funcName, mod); + + func->setCallingConv(CallingConv::C); + func->addFnAttr(Attribute::NoUnwind); + + for (auto attrib : attribs) { + switch (attrib) { + default: + func->addFnAttr(attrib); + break; + case Attribute::ReadNone: + func->setDoesNotAccessMemory(); + break; + case Attribute::ReadOnly: + func->setOnlyReadsMemory(); + break; + case Attribute::WriteOnly: + func->setOnlyWritesMemory(); + break; + } + } + } + + auto *call = builder.CreateCall(func, args, instName); + call->setCallingConv(CallingConv::C); + call->setAttributes(func->getAttributes()); + + return call; +} + +// Modify the function argument types, and return the new function. NOTE: the +// function does not do any uses replacement, so the caller should call +// replaceAllUsesWith() for the function and arguments afterwards. +Function *CompilerUtils::mutateFunctionArguments(Function &fn, Type *retTy, const ArrayRef argTys, + AttributeList attributes) { + FunctionType *newFnTy = FunctionType::get(retTy, argTys, false); + auto *newFn = Function::Create(newFnTy, fn.getLinkage()); + newFn->copyAttributesFrom(&fn); + newFn->copyMetadata(&fn, 0); + newFn->takeName(&fn); + newFn->setAttributes(attributes); + newFn->splice(newFn->begin(), &fn); + fn.getParent()->getFunctionList().insertAfter(fn.getIterator(), newFn); + return newFn; +} diff --git a/lgc/util/TypeLowering.cpp b/compilerutils/lib/TypeLowering.cpp similarity index 67% rename from lgc/util/TypeLowering.cpp rename to compilerutils/lib/TypeLowering.cpp index 7a0e050ed7..4beadfc032 100644 --- a/lgc/util/TypeLowering.cpp +++ b/compilerutils/lib/TypeLowering.cpp @@ -1,82 +1,109 @@ -#include "lgc/util/TypeLowering.h" -#include "lgc/util/Internal.h" +/* + *********************************************************************************************************************** + * + * Copyright (c) 2020-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ + +#include "compilerutils/TypeLowering.h" +#include "compilerutils/CompilerUtils.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" using namespace llvm; -using namespace lgc; namespace { // ===================================================================================================================== -// Fallback converter used by all TypeLowering instances for handling aggregate types. +// Fallback converter used by all TypeLowering instances for handling aggregate +// types. // -// @param typeLowering : the calling TypeLowering object -// @param type : the type to be converted -SmallVector coreTypeConverter(TypeLowering &typeLowering, Type *type) { +// @param typeLower : the calling TypeLowering object +// @param ty : the type to be converted +SmallVector coreTypeConverter(TypeLowering &typeLower, Type *ty) { SmallVector result; - if (auto *arrayType = dyn_cast(type)) { - Type *element = arrayType->getElementType(); - auto converted = typeLowering.convertType(element); - if (converted.size() != 1 || converted[0] != element) { - Type *newElement; + if (auto *arrayTy = dyn_cast(ty)) { + Type *elTy = arrayTy->getElementType(); + auto converted = typeLower.convertType(elTy); + if (converted.size() != 1 || converted[0] != elTy) { + Type *newElTy; if (converted.size() == 1) - newElement = converted[0]; + newElTy = converted[0]; else - newElement = StructType::get(element->getContext(), converted); - result.push_back(ArrayType::get(newElement, arrayType->getNumElements())); + newElTy = StructType::get(elTy->getContext(), converted); + result.push_back(ArrayType::get(newElTy, arrayTy->getNumElements())); return result; } - } else if (auto *structType = dyn_cast(type)) { + } else if (auto *structTy = dyn_cast(ty)) { SmallVector newElements; - newElements.reserve(structType->getNumElements()); + newElements.reserve(structTy->getNumElements()); bool needConversion = false; - for (Type *element : structType->elements()) { - auto converted = typeLowering.convertType(element); - if (converted.size() != 1 || converted[0] != element) + for (Type *elTy : structTy->elements()) { + auto converted = typeLower.convertType(elTy); + if (converted.size() != 1 || converted[0] != elTy) needConversion = true; if (converted.size() == 1) { newElements.push_back(converted[0]); } else { - newElements.push_back(StructType::get(structType->getContext(), converted)); + newElements.push_back(StructType::get(structTy->getContext(), converted)); } } if (needConversion) { - assert(!structType->isPacked()); + assert(!structTy->isPacked()); - if (structType->isLiteral()) { - result.push_back(StructType::get(structType->getContext(), newElements)); + if (structTy->isLiteral()) { + result.push_back(StructType::get(structTy->getContext(), newElements)); } else { - result.push_back(StructType::create(structType->getContext(), newElements, structType->getName())); + result.push_back(StructType::create(structTy->getContext(), newElements, structTy->getName())); } return result; } } - // Since this converter is always called last, we know at this point that the type is not converted. - result.push_back(type); + // Since this converter is always called last, we know at this point that the + // type is not converted. + result.push_back(ty); return result; } // ===================================================================================================================== -// Fallback converter for constants. Provides default handling for poison, undef, and null/zeroinitializer. +// Fallback converter for constants. Provides default handling for poison, +// undef, and null/zeroinitializer. // -// @param typeLowering : the calling TypeLowering object -// @param constant : the constant to be converted -// @param type : the types into which the constant is to be converted -SmallVector coreConstantConverter(TypeLowering &typeLowering, Constant *constant, ArrayRef types) { +// @param typeLower : the calling TypeLowering object +// @param constant : the constant to be Converted +// @param types : the types into which the constant is to be converted +SmallVector coreConstantConverter(TypeLowering &typeLower, Constant *constant, ArrayRef types) { SmallVector result; if (isa(constant)) { - for (Type *type : types) - result.push_back(PoisonValue::get(type)); + for (Type *ty : types) + result.push_back(PoisonValue::get(ty)); } else if (isa(constant)) { - for (Type *type : types) - result.push_back(UndefValue::get(type)); + for (Type *ty : types) + result.push_back(UndefValue::get(ty)); } else if (constant->isNullValue()) { - for (Type *type : types) - result.push_back(Constant::getNullValue(type)); + for (Type *ty : types) + result.push_back(Constant::getNullValue(ty)); } return result; } @@ -93,10 +120,12 @@ TypeLowering::TypeLowering(LLVMContext &context) : m_builder(context) { } // ===================================================================================================================== -// Lower function argument type based on the registered rules. If there is no type remapping needed, will just return -// the old function, otherwise it will move all the instructions in the old function to the new function and return the -// new function. So don't operate on the old function if new function was returned! The old function will be cleaned up -// at the time of TypeLowering::finishCleanup(). +// Lower function argument type based on the registered rules. If there is no +// type remapping needed, will just return the old function, otherwise it will +// move all the instructions in the old function to the new function and return +// the new function. So don't operate on the old function if new function was +// returned! The old function will be cleaned up at the time of +// TypeLowering::finishCleanup(). // Function *TypeLowering::lowerFunctionArguments(Function &fn) { SmallVector newArgTys; @@ -116,7 +145,7 @@ Function *TypeLowering::lowerFunctionArguments(Function &fn) { if (remappedArgs.empty()) return &fn; - auto *newFn = mutateFunctionArguments(fn, fn.getReturnType(), newArgTys, fn.getAttributes()); + auto *newFn = CompilerUtils::mutateFunctionArguments(fn, fn.getReturnType(), newArgTys, fn.getAttributes()); fn.replaceAllUsesWith(newFn); for (unsigned argIdx : remappedArgs) recordValue(fn.getArg(argIdx), {newFn->getArg(argIdx)}); @@ -130,7 +159,7 @@ Function *TypeLowering::lowerFunctionArguments(Function &fn) { if (!llvm::is_contained(remappedArgs, idx)) oldArg->replaceAllUsesWith(newArg); } - m_functionToErase.push_back(&fn); + m_functionsToErase.push_back(&fn); return newFn; } @@ -157,27 +186,28 @@ void TypeLowering::addConstantRule(std::function rule) { // ===================================================================================================================== // Determine the type(s) that a given type should be converted to. // -// For types that *shouldn't* be converted, this returns a singleton array whose only entry is the given type. +// For types that *shouldn't* be converted, this returns a singleton array whose +// only entry is the given type. // -// @param type : the type -ArrayRef TypeLowering::convertType(Type *type) { - auto unaryIt = m_unaryTypeConversions.find(type); +// @param ty : the type +ArrayRef TypeLowering::convertType(Type *ty) { + auto unaryIt = m_unaryTypeConversions.find(ty); if (unaryIt != m_unaryTypeConversions.end()) return ArrayRef(unaryIt->second); - auto multiIt = m_multiTypeConversions.find(type); + auto multiIt = m_multiTypeConversions.find(ty); if (multiIt != m_multiTypeConversions.end()) return multiIt->second; - for (auto rule : reverse(m_rules)) { - SmallVector types = rule(*this, type); + for (const auto &rule : reverse(m_rules)) { + SmallVector types = rule(*this, ty); if (types.empty()) continue; if (types.size() == 1) - return ArrayRef(m_unaryTypeConversions.try_emplace(type, types[0]).first->second); + return ArrayRef(m_unaryTypeConversions.try_emplace(ty, types[0]).first->second); - return m_multiTypeConversions.try_emplace(type, std::move(types)).first->second; + return m_multiTypeConversions.try_emplace(ty, std::move(types)).first->second; } llvm_unreachable("core/fallback rule should prevent us from reaching this point"); @@ -188,25 +218,26 @@ ArrayRef TypeLowering::convertType(Type *type) { // // @param builder : the VisitorBuilder void TypeLowering::registerVisitors(llvm_dialects::VisitorBuilder &builder) { - builder.setStrategy(llvm_dialects::VisitorStrategy::ReversePostOrder); - builder.add(&TypeLowering::visitAlloca); - builder.add(&TypeLowering::visitExtract); - builder.add(&TypeLowering::visitInsert); - builder.add(&TypeLowering::visitLoad); - builder.add(&TypeLowering::visitPhi); - builder.add(&TypeLowering::visitSelect); - builder.add(&TypeLowering::visitStore); + builder.setStrategy(llvm_dialects::VisitorStrategy::ReversePostOrder) + .add(&TypeLowering::visitAlloca) + .add(&TypeLowering::visitExtract) + .add(&TypeLowering::visitInsert) + .add(&TypeLowering::visitLoad) + .add(&TypeLowering::visitPhi) + .add(&TypeLowering::visitSelect) + .add(&TypeLowering::visitStore); } // ===================================================================================================================== // Lookup the mapping of a value that has previously been added. // -// In typical uses of this helper class, the lookup must be successful since instructions are visited in reverse -// post-order, and phi nodes are fixed up at the end. Therefore, this method should be preferred over getValueOptional. +// In typical uses of this helper function, the lookup must be successful since +// instructions are visited in reverse post-order, and phi nodes are fixed up at +// the end. Therefore, this method should be preferred over getValueOptional. // -// @param value : the value -SmallVector TypeLowering::getValue(Value *value) { - auto values = getValueOptional(value); +// @param val : the value +SmallVector TypeLowering::getValue(Value *val) { + auto values = getValueOptional(val); assert(!values.empty()); return values; } @@ -214,16 +245,16 @@ SmallVector TypeLowering::getValue(Value *value) { // ===================================================================================================================== // Lookup a previously added mapping of a given value. // -// Return an empty value list if the given value is unknown, i.e. has not been converted. Most users should use -// getValue instead. +// Return an empty value list if the given value is unknown, val has not been +// converted. Most users should use getValue instead. // // Note that constant conversion is invoked on-the-fly as needed. // -// @param value : the value -SmallVector TypeLowering::getValueOptional(Value *value) { - auto valueIt = m_valueMap.find(value); +// @param val : the value +SmallVector TypeLowering::getValueOptional(Value *val) { + auto valueIt = m_valueMap.find(val); if (valueIt == m_valueMap.end()) { - auto *constant = dyn_cast(value); + auto *constant = dyn_cast(val); if (!constant) return {}; @@ -235,7 +266,7 @@ SmallVector TypeLowering::getValueOptional(Value *value) { if (types.size() == 1 && types[0] == constant->getType()) { converted.push_back(constant); } else { - for (auto rule : reverse(m_constantRules)) { + for (const auto &rule : reverse(m_constantRules)) { SmallVector constants = rule(*this, constant, types); if (!constants.empty()) { converted.insert(converted.end(), constants.begin(), constants.end()); @@ -245,9 +276,9 @@ SmallVector TypeLowering::getValueOptional(Value *value) { assert(!converted.empty() && "missing constant conversion rule"); } - recordValue(value, converted); + recordValue(val, converted); - valueIt = m_valueMap.find(value); + valueIt = m_valueMap.find(val); assert(valueIt != m_valueMap.end()); } @@ -256,18 +287,19 @@ SmallVector TypeLowering::getValueOptional(Value *value) { } size_t begin = valueIt->second >> 1; - auto typeIt = m_multiTypeConversions.find(value->getType()); + auto typeIt = m_multiTypeConversions.find(val->getType()); assert(typeIt != m_multiTypeConversions.end()); size_t count = typeIt->second.size(); return SmallVector(ArrayRef(&m_convertedValueList[begin], count)); } // ===================================================================================================================== -// Record that the value produced by the given instruction should be mapped to the given new value(s), and that the -// instruction should be erased. +// Record that the value produced by the given instruction should be mapped to +// the given new value(s), and that the instruction should be erased. // // @param inst : the instruction -// @param mapping : the value(s) that the value defined by the instruction should be mapped to +// @param mapping : the value(s) that the value defined by the instruction +// should be mapped to void TypeLowering::replaceInstruction(Instruction *inst, ArrayRef mapping) { m_instructionsToErase.push_back(inst); @@ -282,35 +314,36 @@ void TypeLowering::replaceInstruction(Instruction *inst, ArrayRef mappi // ===================================================================================================================== // Record a mapping for a value. // -// @param value : the value for which a mapping is recorded +// @param val : the value for which a mapping is recorded // @param mapping : the mapping that is recorded for the value -void TypeLowering::recordValue(Value *value, ArrayRef mapping) { - assert(!m_valueMap.count(value)); +void TypeLowering::recordValue(Value *val, ArrayRef mapping) { + assert(!m_valueMap.count(val)); if (mapping.size() == 1) { - m_valueMap.try_emplace(value, reinterpret_cast(mapping[0])); + m_valueMap.try_emplace(val, reinterpret_cast(mapping[0])); #ifndef NDEBUG - auto types = convertType(value->getType()); + auto types = convertType(val->getType()); assert(types.size() == 1); assert(types[0] == mapping[0]->getType()); #endif - m_valueReverseMap[mapping[0]].emplace_back(reinterpret_cast(value)); + m_valueReverseMap[mapping[0]].emplace_back(reinterpret_cast(val)); return; } uintptr_t index = m_convertedValueList.size(); uintptr_t code = (index << 1) | 1; m_convertedValueList.insert(m_convertedValueList.end(), mapping.begin(), mapping.end()); - m_valueMap.try_emplace(value, code); + m_valueMap.try_emplace(val, code); for (auto e : llvm::enumerate(mapping)) { m_valueReverseMap[e.value()].emplace_back(((index + e.index()) << 1) | 1); } - // Unconditionally perform the conversion to ensure that it is available in getValue. - auto types = convertType(value->getType()); + // Unconditionally perform the conversion to ensure that it is available in + // getValue. + auto types = convertType(val->getType()); assert(types.size() == mapping.size()); - for (size_t i = 0; i < types.size(); ++i) { - assert(types[i] == mapping[i]->getType()); + for (size_t idx = 0; idx < types.size(); ++idx) { + assert(types[idx] == mapping[idx]->getType()); } } @@ -323,16 +356,19 @@ void TypeLowering::eraseInstruction(llvm::Instruction *inst) { } // ===================================================================================================================== -// Replace a value that may have previously been recorded as part of a mapping with another value. +// Replace a value that may have previously been recorded as part of a mapping +// with another value. // -// This can be used if RAUW is performed after the main traversal of the code, as in: +// This can be used if RAUW is performed after the main traversal of the code, +// as in: // @code // toReplace->replaceAllUsesWith(with); -// typeLowering.replaceMappingWith(toReplace, with); +// typeLower.replaceMappingWith(toReplace, with); // @endcode // // @param toReplace : the mapping value to be replaced -// @param with : the new value to replace it with in all mappings in which it appears +// @param with : the new value to replace it with in all mappings in which +// it appears void TypeLowering::replaceMappingWith(Value *toReplace, Value *with) { if (toReplace == with) return; @@ -363,19 +399,23 @@ void TypeLowering::replaceMappingWith(Value *toReplace, Value *with) { // ===================================================================================================================== // Finalize phi nodes. // -// This performs some trivial simplifications but does not actually erase the old phi nodes yet. +// This performs some trivial simplifications but does not actually erase the +// old phi nodes yet. void TypeLowering::finishPhis() { - // Process phis in reverse order, so that phis from inner loops are handled before phis from outer loops. + // Process phis in reverse order, so that phis from inner loops are handled + // before phis from outer loops. // - // Trivial phis are simplified on-the-fly. Trivial phis can occur when a value is replaced by a tuple of values and - // some of the tuple entries are constant across a loop while others aren't. + // Trivial phis are simplified on-the-fly. Trivial phis can occur when a value + // is replaced by a tuple of values and some of the tuple entries are constant + // across a loop while others aren't. for (const auto &[phi, newPhis] : llvm::reverse(m_phis)) { - // None means no non-self incoming found. nullptr means multiple non-self incomings found. + // None means no non-self incoming found. nullptr means multiple non-self + // incomings found. SmallVector> uniqueNonSelfIncomings; uniqueNonSelfIncomings.resize(newPhis.size()); - for (const auto &[block, value] : llvm::zip(phi->blocks(), phi->incoming_values())) { - auto converted = getValue(value); + for (const auto &[block, val] : llvm::zip(phi->blocks(), phi->incoming_values())) { + auto converted = getValue(val); for (auto [newPhi, newValue, uniqueNonSelf] : llvm::zip(newPhis, converted, uniqueNonSelfIncomings)) { if (newValue != newPhi) { if (!uniqueNonSelf.has_value()) { @@ -396,16 +436,19 @@ void TypeLowering::finishPhis() { } if (Value *replace = uniqueNonSelf.value()) { - // All incomings are either the phi itself or some unique value. This means that unique value must dominate - // the phi and so we can just replace it. + // All incomings are either the phi itself or some unique value. This + // means that unique value must dominate the phi and so we can just + // replace it. newPhi->replaceAllUsesWith(replace); replaceMappingWith(newPhi, replace); eraseInstruction(newPhi); } } - // Phis may be visited *before* the incoming values, which means that finishCleanup() will attempt to delete some - // incoming values *before* the phi. Drop all references so that the incoming values can be deleted without issues. + // Phis may be visited *before* the incoming values, which means that + // finishCleanup() will attempt to delete some incoming values *before* the + // phi. Drop all references so that the incoming values can be deleted + // without issues. phi->dropAllReferences(); } m_phis.clear(); @@ -418,14 +461,15 @@ bool TypeLowering::finishCleanup() { bool changed = !m_instructionsToErase.empty(); - // We can just erase instructions in reverse order since we added them in reverse post-order. + // We can just erase instructions in reverse order since we added them in + // reverse post-order. for (Instruction *inst : llvm::reverse(m_instructionsToErase)) inst->eraseFromParent(); m_instructionsToErase.clear(); - for (Function *fn : m_functionToErase) + for (Function *fn : m_functionsToErase) fn->eraseFromParent(); - m_functionToErase.clear(); + m_functionsToErase.clear(); m_valueMap.clear(); m_convertedValueList.clear(); @@ -470,8 +514,8 @@ void TypeLowering::visitExtract(ExtractValueInst &extract) { if (types.size() == 1) { converted.push_back(newExtract); } else { - for (size_t i = 0; i < types.size(); ++i) - converted.push_back(m_builder.CreateExtractValue(newExtract, i)); + for (size_t idx = 0; idx < types.size(); ++idx) + converted.push_back(m_builder.CreateExtractValue(newExtract, idx)); } replaceInstruction(&extract, converted); @@ -529,11 +573,13 @@ void TypeLowering::visitLoad(LoadInst &load) { loadType = StructType::get(m_builder.getContext(), types); } - // We create an entirely new load instruction and explicitly make no attempt to preserve any assorted data like - // alignment, atomicity, and metadata. Since we are replacing the load of a likely "opaque" type whose size (as far - // as LLVM is concerned) may not even match its replacement, any such data is most likely useless at best and - // incorrect at worst. We should eventually figure out how to handle this properly, but it likely means LLVM - // accepting the notion of "opaque" types to some extent. + // We create an entirely new load instruction and explicitly make no attempt + // to preserve any assorted data like alignment, atomicity, and metadata. + // Since we are replacing the load of a likely "opaque" type whose size (as + // far as LLVM is concerned) may not even match its replacement, any such data + // is most likely useless at best and incorrect at worst. We should eventually + // figure out how to handle this properly, but it likely means LLVM accepting + // the notion of "opaque" Types to some extent. Value *data = m_builder.CreateLoad(loadType, load.getPointerOperand()); data->takeName(&load); @@ -541,8 +587,8 @@ void TypeLowering::visitLoad(LoadInst &load) { if (types.size() == 1) { converted.push_back(data); } else { - for (size_t i = 0; i < types.size(); ++i) - converted.push_back(m_builder.CreateExtractValue(data, i)); + for (size_t idx = 0; idx < types.size(); ++idx) + converted.push_back(m_builder.CreateExtractValue(data, idx)); } replaceInstruction(&load, converted); @@ -591,8 +637,9 @@ void TypeLowering::visitSelect(SelectInst &select) { Value *trueValue = std::get<0>(e.value()); Value *falseValue = std::get<1>(e.value()); - // Simplify selects on the fly. This is relevant when a value is converted into a tuple of values, where some - // entries of the tuple may be more likely to be constant than others. + // Simplify selects on the fly. This is relevant when a value is Converted + // into a tuple of values, where some entries of the tuple may be more + // likely to be constant than others. if (isa(trueValue) || isa(trueValue)) trueValue = falseValue; else if (isa(falseValue) || isa(falseValue)) @@ -623,17 +670,19 @@ void TypeLowering::visitStore(StoreInst &store) { if (values.size() == 1) { data = values[0]; } else { - Type *storeType = StructType::get(m_builder.getContext(), convertType(store.getValueOperand()->getType())); - data = PoisonValue::get(storeType); + Type *storeTy = StructType::get(m_builder.getContext(), convertType(store.getValueOperand()->getType())); + data = PoisonValue::get(storeTy); for (auto e : llvm::enumerate(values)) data = m_builder.CreateInsertValue(data, e.value(), e.index()); } - // We create an entirely new store instruction and explicitly make no attempt to preserve any assorted data like - // alignment, atomicity, and metadata. Since we are replacing the load of a likely "opaque" type whose size (as far - // as LLVM is concerned) may not even match its replacement, any such data is most likely useless at best and - // incorrect at worst. We should eventually figure out how to handle this properly, but it likely means LLVM - // accepting the notion of "opaque" types to some extent. + // We create an entirely new store instruction and explicitly make no attempt + // to preserve any assorted data like alignment, atomicity, and metadata. + // Since we are replacing the load of a likely "opaque" type whose size (as + // far as LLVM is concerned) may not even match its replacement, any such data + // is most likely useless at best and incorrect at worst. We should eventually + // figure out how to handle this properly, but it likely means LLVM accepting + // the notion of "opaque" Types to some extent. m_builder.CreateStore(data, store.getPointerOperand()); replaceInstruction(&store, {}); diff --git a/imported/llvm-dialects b/imported/llvm-dialects index 8c54ca076f..16a0e93317 160000 --- a/imported/llvm-dialects +++ b/imported/llvm-dialects @@ -1 +1 @@ -Subproject commit 8c54ca076fbf841dc5d22da8b6a1d434a01b153c +Subproject commit 16a0e93317979f0b281458a5f3b830e0426983b1 diff --git a/include/khronos/spirv/spirv.hpp b/include/khronos/spirv/spirv.hpp index 421d68432c..1531e8bd33 100644 --- a/include/khronos/spirv/spirv.hpp +++ b/include/khronos/spirv/spirv.hpp @@ -69,6 +69,10 @@ enum SourceLanguage { SourceLanguageHLSL = 5, SourceLanguageCPP_for_OpenCL = 6, SourceLanguageSYCL = 7, + SourceLanguageHERO_C = 8, + SourceLanguageNZSL = 9, + SourceLanguageWGSL = 10, + SourceLanguageSlang = 11, SourceLanguageMax = 0x7fffffff, }; @@ -168,6 +172,11 @@ enum ExecutionMode { ExecutionModeRoundingModeRTZ = 4463, ExecutionModeEarlyAndLateFragmentTestsAMD = 5017, ExecutionModeStencilRefReplacingEXT = 5027, + ExecutionModeCoalescingAMDX = 5069, + ExecutionModeMaxNodeRecursionAMDX = 5071, + ExecutionModeStaticNumWorkgroupsAMDX = 5072, + ExecutionModeShaderIndexAMDX = 5073, + ExecutionModeMaxNumWorkgroupsAMDX = 5077, ExecutionModeStencilRefUnchangedFrontAMD = 5079, ExecutionModeStencilRefGreaterFrontAMD = 5080, ExecutionModeStencilRefLessFrontAMD = 5081, @@ -219,6 +228,8 @@ enum StorageClass { StorageClassImage = 11, StorageClassStorageBuffer = 12, StorageClassTileImageEXT = 4172, + StorageClassNodePayloadAMDX = 5068, + StorageClassNodeOutputPayloadAMDX = 5076, StorageClassCallableDataKHR = 5328, StorageClassCallableDataNV = 5328, StorageClassIncomingCallableDataKHR = 5329, @@ -356,6 +367,8 @@ enum ImageChannelDataType { ImageChannelDataTypeFloat = 14, ImageChannelDataTypeUnormInt24 = 15, ImageChannelDataTypeUnormInt101010_2 = 16, + ImageChannelDataTypeUnsignedIntRaw10EXT = 19, + ImageChannelDataTypeUnsignedIntRaw12EXT = 20, ImageChannelDataTypeMax = 0x7fffffff, }; @@ -517,6 +530,10 @@ enum Decoration { DecorationWeightTextureQCOM = 4487, DecorationBlockMatchTextureQCOM = 4488, DecorationExplicitInterpAMD = 4999, + DecorationNodeSharesPayloadLimitsWithAMDX = 5019, + DecorationNodeMaxPayloadsAMDX = 5020, + DecorationTrackFinishWritingAMDX = 5078, + DecorationPayloadNodeNameAMDX = 5091, DecorationOverrideCoverageNV = 5248, DecorationPassthroughNV = 5250, DecorationViewportRelativeNV = 5252, @@ -584,6 +601,10 @@ enum Decoration { DecorationSingleElementVectorINTEL = 6085, DecorationVectorComputeCallableFunctionINTEL = 6087, DecorationMediaBlockIOINTEL = 6140, + DecorationInitModeINTEL = 6147, + DecorationImplementInRegisterMapINTEL = 6148, + DecorationHostAccessINTEL = 6168, + DecorationFPMaxErrorDecorationINTEL = 6170, DecorationLatencyControlLabelINTEL = 6172, DecorationLatencyControlConstraintINTEL = 6173, DecorationConduitKernelArgumentINTEL = 6175, @@ -595,6 +616,8 @@ enum Decoration { DecorationMMHostInterfaceMaxBurstINTEL = 6181, DecorationMMHostInterfaceWaitRequestINTEL = 6182, DecorationStableKernelArgumentINTEL = 6183, + DecorationCacheControlLoadINTEL = 6442, + DecorationCacheControlStoreINTEL = 6443, DecorationMax = 0x7fffffff, }; @@ -670,6 +693,8 @@ enum BuiltIn { BuiltInBaryCoordSmoothSampleAMD = 4997, BuiltInBaryCoordPullModelAMD = 4998, BuiltInFragStencilRefEXT = 5014, + BuiltInCoalescedInputCountAMDX = 5021, + BuiltInShaderIndexAMDX = 5073, BuiltInViewportMaskNV = 5253, BuiltInSecondaryPositionNV = 5257, BuiltInSecondaryViewportMaskNV = 5258, @@ -723,6 +748,8 @@ enum BuiltIn { BuiltInHitKindNV = 5333, BuiltInCurrentRayTimeNV = 5334, BuiltInHitTriangleVertexPositionsKHR = 5335, + BuiltInHitMicroTriangleVertexPositionsNV = 5337, + BuiltInHitMicroTriangleVertexBarycentricsNV = 5344, BuiltInIncomingRayFlagsKHR = 5351, BuiltInIncomingRayFlagsNV = 5351, BuiltInRayGeometryIndexKHR = 5352, @@ -730,6 +757,8 @@ enum BuiltIn { BuiltInSMCountNV = 5375, BuiltInWarpIDNV = 5376, BuiltInSMIDNV = 5377, + BuiltInHitKindFrontFacingMicroTriangleNV = 5405, + BuiltInHitKindBackFacingMicroTriangleNV = 5406, BuiltInCullMaskKHR = 6021, BuiltInMax = 0x7fffffff, }; @@ -1038,6 +1067,7 @@ enum Capability { CapabilityImageReadWriteLodAMD = 5015, CapabilityInt64ImageEXT = 5016, CapabilityShaderClockKHR = 5055, + CapabilityShaderEnqueueAMDX = 5067, CapabilitySampleMaskOverrideCoverageNV = 5249, CapabilityGeometryShaderPassthroughNV = 5251, CapabilityShaderViewportIndexLayerEXT = 5254, @@ -1097,10 +1127,12 @@ enum Capability { CapabilityFragmentShaderPixelInterlockEXT = 5378, CapabilityDemoteToHelperInvocation = 5379, CapabilityDemoteToHelperInvocationEXT = 5379, + CapabilityDisplacementMicromapNV = 5380, CapabilityRayTracingOpacityMicromapEXT = 5381, CapabilityShaderInvocationReorderNV = 5383, CapabilityBindlessTextureNV = 5390, CapabilityRayQueryPositionFetchKHR = 5391, + CapabilityRayTracingDisplacementMicromapNV = 5409, CapabilitySubgroupShuffleINTEL = 5568, CapabilitySubgroupBufferBlockIOINTEL = 5569, CapabilitySubgroupImageBlockIOINTEL = 5570, @@ -1152,6 +1184,7 @@ enum Capability { CapabilityDotProduct = 6019, CapabilityDotProductKHR = 6019, CapabilityRayCullMaskKHR = 6020, + CapabilityCooperativeMatrixKHR = 6022, CapabilityBitInstructions = 6025, CapabilityGroupNonUniformRotateKHR = 6026, CapabilityAtomicFloat32AddEXT = 6033, @@ -1162,10 +1195,14 @@ enum Capability { CapabilityDebugInfoModuleINTEL = 6114, CapabilityBFloat16ConversionINTEL = 6115, CapabilitySplitBarrierINTEL = 6141, + CapabilityGlobalVariableFPGADecorationsINTEL = 6146, CapabilityFPGAKernelAttributesv2INTEL = 6161, + CapabilityGlobalVariableHostAccessINTEL = 6167, + CapabilityFPMaxErrorINTEL = 6169, CapabilityFPGALatencyControlINTEL = 6171, CapabilityFPGAArgumentInterfacesINTEL = 6174, CapabilityGroupUniformArithmeticKHR = 6400, + CapabilityCacheControlsINTEL = 6441, CapabilityMax = 0x7fffffff, }; @@ -1272,6 +1309,68 @@ enum PackedVectorFormat { PackedVectorFormatMax = 0x7fffffff, }; +enum CooperativeMatrixOperandsShift { + CooperativeMatrixOperandsMatrixASignedComponentsKHRShift = 0, + CooperativeMatrixOperandsMatrixBSignedComponentsKHRShift = 1, + CooperativeMatrixOperandsMatrixCSignedComponentsKHRShift = 2, + CooperativeMatrixOperandsMatrixResultSignedComponentsKHRShift = 3, + CooperativeMatrixOperandsSaturatingAccumulationKHRShift = 4, + CooperativeMatrixOperandsMax = 0x7fffffff, +}; + +enum CooperativeMatrixOperandsMask { + CooperativeMatrixOperandsMaskNone = 0, + CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001, + CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002, + CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004, + CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008, + CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010, +}; + +enum CooperativeMatrixLayout { + CooperativeMatrixLayoutRowMajorKHR = 0, + CooperativeMatrixLayoutColumnMajorKHR = 1, + CooperativeMatrixLayoutMax = 0x7fffffff, +}; + +enum CooperativeMatrixUse { + CooperativeMatrixUseMatrixAKHR = 0, + CooperativeMatrixUseMatrixBKHR = 1, + CooperativeMatrixUseMatrixAccumulatorKHR = 2, + CooperativeMatrixUseMax = 0x7fffffff, +}; + +enum InitializationModeQualifier { + InitializationModeQualifierInitOnDeviceReprogramINTEL = 0, + InitializationModeQualifierInitOnDeviceResetINTEL = 1, + InitializationModeQualifierMax = 0x7fffffff, +}; + +enum HostAccessQualifier { + HostAccessQualifierNoneINTEL = 0, + HostAccessQualifierReadINTEL = 1, + HostAccessQualifierWriteINTEL = 2, + HostAccessQualifierReadWriteINTEL = 3, + HostAccessQualifierMax = 0x7fffffff, +}; + +enum LoadCacheControl { + LoadCacheControlUncachedINTEL = 0, + LoadCacheControlCachedINTEL = 1, + LoadCacheControlStreamingINTEL = 2, + LoadCacheControlInvalidateAfterReadINTEL = 3, + LoadCacheControlConstCachedINTEL = 4, + LoadCacheControlMax = 0x7fffffff, +}; + +enum StoreCacheControl { + StoreCacheControlUncachedINTEL = 0, + StoreCacheControlWriteThroughINTEL = 1, + StoreCacheControlWriteBackINTEL = 2, + StoreCacheControlStreamingINTEL = 3, + StoreCacheControlMax = 0x7fffffff, +}; + enum Op { OpNop = 0, OpUndef = 1, @@ -1645,6 +1744,11 @@ enum Op { OpUDotAccSatKHR = 4454, OpSUDotAccSat = 4455, OpSUDotAccSatKHR = 4455, + OpTypeCooperativeMatrixKHR = 4456, + OpCooperativeMatrixLoadKHR = 4457, + OpCooperativeMatrixStoreKHR = 4458, + OpCooperativeMatrixMulAddKHR = 4459, + OpCooperativeMatrixLengthKHR = 4460, OpTypeRayQueryKHR = 4472, OpRayQueryInitializeKHR = 4473, OpRayQueryTerminateKHR = 4474, @@ -1667,6 +1771,9 @@ enum Op { OpFragmentMaskFetchAMD = 5011, OpFragmentFetchAMD = 5012, OpReadClockKHR = 5056, + OpFinalizeNodePayloadsAMDX = 5075, + OpFinishWritingNodePayloadAMDX = 5078, + OpInitializeNodePayloadsAMDX = 5090, OpHitObjectRecordHitMotionNV = 5249, OpHitObjectRecordHitWithIndexMotionNV = 5250, OpHitObjectRecordMissMotionNV = 5251, @@ -1705,6 +1812,8 @@ enum Op { OpSetMeshOutputsEXT = 5295, OpGroupNonUniformPartitionNV = 5296, OpWritePackedPrimitiveIndices4x8NV = 5299, + OpFetchMicroTriangleVertexPositionNV = 5300, + OpFetchMicroTriangleVertexBarycentricNV = 5301, OpReportIntersectionKHR = 5334, OpReportIntersectionNV = 5334, OpIgnoreIntersectionNV = 5335, @@ -2363,6 +2472,11 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) { case OpSDotAccSat: *hasResult = true; *hasResultType = true; break; case OpUDotAccSat: *hasResult = true; *hasResultType = true; break; case OpSUDotAccSat: *hasResult = true; *hasResultType = true; break; + case OpTypeCooperativeMatrixKHR: *hasResult = true; *hasResultType = false; break; + case OpCooperativeMatrixLoadKHR: *hasResult = true; *hasResultType = true; break; + case OpCooperativeMatrixStoreKHR: *hasResult = false; *hasResultType = false; break; + case OpCooperativeMatrixMulAddKHR: *hasResult = true; *hasResultType = true; break; + case OpCooperativeMatrixLengthKHR: *hasResult = true; *hasResultType = true; break; case OpTypeRayQueryKHR: *hasResult = true; *hasResultType = false; break; case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break; case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break; @@ -2385,6 +2499,9 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) { case OpFragmentMaskFetchAMD: *hasResult = true; *hasResultType = true; break; case OpFragmentFetchAMD: *hasResult = true; *hasResultType = true; break; case OpReadClockKHR: *hasResult = true; *hasResultType = true; break; + case OpFinalizeNodePayloadsAMDX: *hasResult = false; *hasResultType = false; break; + case OpFinishWritingNodePayloadAMDX: *hasResult = true; *hasResultType = true; break; + case OpInitializeNodePayloadsAMDX: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordHitMotionNV: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordHitWithIndexMotionNV: *hasResult = false; *hasResultType = false; break; case OpHitObjectRecordMissMotionNV: *hasResult = false; *hasResultType = false; break; @@ -2423,6 +2540,8 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) { case OpSetMeshOutputsEXT: *hasResult = false; *hasResultType = false; break; case OpGroupNonUniformPartitionNV: *hasResult = true; *hasResultType = true; break; case OpWritePackedPrimitiveIndices4x8NV: *hasResult = false; *hasResultType = false; break; + case OpFetchMicroTriangleVertexPositionNV: *hasResult = true; *hasResultType = true; break; + case OpFetchMicroTriangleVertexBarycentricNV: *hasResult = true; *hasResultType = true; break; case OpReportIntersectionNV: *hasResult = true; *hasResultType = true; break; case OpIgnoreIntersectionNV: *hasResult = false; *hasResultType = false; break; case OpTerminateRayNV: *hasResult = false; *hasResultType = false; break; @@ -2745,6 +2864,10 @@ inline FragmentShadingRateMask operator|(FragmentShadingRateMask a, FragmentShad inline FragmentShadingRateMask operator&(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) & unsigned(b)); } inline FragmentShadingRateMask operator^(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) ^ unsigned(b)); } inline FragmentShadingRateMask operator~(FragmentShadingRateMask a) { return FragmentShadingRateMask(~unsigned(a)); } +inline CooperativeMatrixOperandsMask operator|(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) | unsigned(b)); } +inline CooperativeMatrixOperandsMask operator&(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) & unsigned(b)); } +inline CooperativeMatrixOperandsMask operator^(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b) { return CooperativeMatrixOperandsMask(unsigned(a) ^ unsigned(b)); } +inline CooperativeMatrixOperandsMask operator~(CooperativeMatrixOperandsMask a) { return CooperativeMatrixOperandsMask(~unsigned(a)); } } // end namespace spv diff --git a/include/vkgcBase.h b/include/vkgcBase.h index a78be2e8fb..ee57be0e91 100644 --- a/include/vkgcBase.h +++ b/include/vkgcBase.h @@ -79,6 +79,7 @@ enum RAYTRACING_ENTRY_FUNC : unsigned { RT_ENTRY_INSTANCE_ID, RT_ENTRY_OBJECT_TO_WORLD_TRANSFORM, RT_ENTRY_WORLD_TO_OBJECT_TRANSFORM, + RT_ENTRY_GET_INSTANCE_NODE, RT_ENTRY_RESERVE1, RT_ENTRY_RESERVE2, RT_ENTRY_FETCH_HIT_TRIANGLE_FROM_NODE_POINTER, diff --git a/include/vkgcDefs.h b/include/vkgcDefs.h index ffab2a2d02..06623523d0 100644 --- a/include/vkgcDefs.h +++ b/include/vkgcDefs.h @@ -46,10 +46,10 @@ #endif /// LLPC major interface version. -#define LLPC_INTERFACE_MAJOR_VERSION 69 +#define LLPC_INTERFACE_MAJOR_VERSION 70 /// LLPC minor interface version. -#define LLPC_INTERFACE_MINOR_VERSION 2 +#define LLPC_INTERFACE_MINOR_VERSION 1 #ifndef LLPC_CLIENT_INTERFACE_MAJOR_VERSION #error LLPC client version is not defined @@ -80,7 +80,8 @@ // %Version History // | %Version | Change Description | // | -------- | ----------------------------------------------------------------------------------------------------- | -// | 69.2 | Add enablePrimGeneratedQuery to PipelineOptions | +// | 70.1 | Add cpsFlags to RayTracingPipelineBuildInfo | +// | 70.0 | Add enablePrimGeneratedQuery to PipelineOptions | // | 69.1 | Add useBarycentric to ShaderModuleUsage | // | 69.0 | Enable continuations transform in LLPC | // | 68.0 | Remove ICache *cache in all PipelineBuildInfo | @@ -356,7 +357,7 @@ static_assert((1 << (ShaderStageCount - 1)) == ShaderStageRayTracingCallableBit, /// Enumerates the binding ID of internal resource. enum InternalBinding : unsigned { FetchShaderBinding = 0, ///< Binding ID of vertex buffer table - ConstantBuffer0Binding = 1, ///< Binding ID of default uniform block + CurrentAttributeBufferBinding = 1, ///< Binding ID of current attribute PushConstantBinding = 2, ///< Binding ID of push constant buffer ShaderRecordBufferBinding = 3, ///< Binding ID of ray-tracing shader record buffer TaskPayloadBinding = 4, ///< Binding ID of payload buffer in task shader @@ -366,7 +367,8 @@ enum InternalBinding : unsigned { RtCaptureReplayInternalBufferBinding = 8, ///< Binding ID of ray-tracing capture replay internal buffer SpecConstInternalBufferBindingId = 9, ///< Binding ID of internal buffer for specialized constant. SpecConstInternalBufferBindingIdEnd = SpecConstInternalBufferBindingId + ShaderStageCount, - CurrentAttributeBufferBinding = 24, ///< Binding ID of current attribute + ConstantBuffer0Binding = 24, ///< Binding ID of default uniform block + ConstantBuffer0BindingEnd = ConstantBuffer0Binding + ShaderStageGfxCount, }; /// Internal vertex attribute location start from 0. @@ -606,8 +608,9 @@ struct PipelineOptions { bool enableCombinedTexture; ///< For OGL only, use the 'set' for DescriptorCombinedTexture ///< for sampled images and samplers bool vertex64BitsAttribSingleLoc; ///< For OGL only, dvec3/dvec4 vertex attrib only consumes 1 location. - bool enablePrimGeneratedQuery; ///< If set, primitive generated query is enabled + bool enableFragColor; ///< For OGL only, need to do frag color broadcast if it is enabled. unsigned reserved20; + bool enablePrimGeneratedQuery; ///< If set, primitive generated query is enabled }; /// Prototype of allocator for output data buffer, used in shader-specific operations. @@ -625,10 +628,12 @@ enum class BinaryType : unsigned { /// Represents resource node data struct ResourceNodeData { ResourceMappingNodeType type; ///< Type of this resource mapping node + unsigned spvId; ///< ID of variable unsigned set; ///< ID of descriptor set unsigned binding; ///< ID of descriptor binding unsigned arraySize; ///< Element count for arrayed binding unsigned location; ///< ID of resource location + bool mergedLocationBinding; ///< TRUE if location and binding are merged in spirv binary unsigned isTexelBuffer; ///< TRUE if it is ImageBuffer or TextureBuffer unsigned isDefaultUniformSampler; ///< TRUE if it's sampler image in default uniform struct BasicType basicType; ///< Type of the variable or element @@ -929,6 +934,9 @@ struct PipelineShaderOptions { /// Application workaround: forward propagate NoContraction decoration to any related FAdd operation. bool forwardPropagateNoContract; + + /// Binding ID offset of default uniform block + unsigned constantBufferBindingOffset; }; /// Represents YCbCr sampler meta data in resource descriptor @@ -1187,6 +1195,12 @@ enum class LlpcRaytracingMode : unsigned { Continuations, // Enable continuation in the new raytracing path }; +// Enumerate feature flags for CPS. +enum CpsFlag : unsigned { + CpsNoFlag = 0, + CpsFlagStackInGlobalMem = 1 << 0, // Put stack in global memory instead of scratch. +}; + /// RayTracing state struct RtState { unsigned nodeStrideShift; ///< Ray tracing BVH node stride @@ -1267,7 +1281,7 @@ struct ApiXfbOutData { XfbOutInfo *pXfbOutInfos; ///< An array of XfbOutInfo items unsigned numXfbOutInfo; ///< Count of XfbOutInfo items bool forceDisableStreamOut; ///< Force to disable stream out XFB outputs -#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 69 +#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 70 bool forceEnablePrimStats; ///< Force to enable counting generated primitives #endif }; @@ -1422,6 +1436,7 @@ struct RayTracingPipelineBuildInfo { const void *pClientMetadata; ///< Pointer to (optional) client-defined data to be /// stored inside the ELF size_t clientMetadataSize; ///< Size (in bytes) of the client-defined data + unsigned cpsFlags; ///< Cps feature flags }; /// Ray tracing max shader name length diff --git a/lgc/CMakeLists.txt b/lgc/CMakeLists.txt index 8de9a01b15..380d46b2b5 100644 --- a/lgc/CMakeLists.txt +++ b/lgc/CMakeLists.txt @@ -53,7 +53,8 @@ add_llvm_library(LLVMlgc LINK_COMPONENTS Vectorize ) -target_link_libraries(LLVMlgc PUBLIC LLVMContinuations) +llvm_map_components_to_libnames(extra_llvm_libs CompilerUtils Continuations) +target_link_libraries(LLVMlgc PUBLIC ${extra_llvm_libs}) ### Cached Project Options ############################################################################################# option(LLPC_BUILD_NAVI12 "LLPC support for NAVI12?" ON) @@ -236,6 +237,8 @@ target_sources(LLVMlgc PRIVATE patch/VertexFetch.cpp patch/PatchImageOpCollect.cpp patch/RegisterMetadataBuilder.cpp + patch/CombineCooperativeMatrix.cpp + patch/LowerCooperativeMatrix.cpp ) # lgc/state @@ -257,7 +260,6 @@ target_sources(LLVMlgc PRIVATE target_sources(LLVMlgc PRIVATE util/AddressExtender.cpp util/Debug.cpp - util/CpsStackLowering.cpp util/GfxRegHandlerBase.cpp util/GfxRegHandler.cpp util/Internal.cpp @@ -265,7 +267,6 @@ target_sources(LLVMlgc PRIVATE util/ModuleBunch.cpp util/PassManager.cpp util/StartStopTimer.cpp - util/TypeLowering.cpp ) add_subdirectory(disassembler) diff --git a/lgc/builder/BuilderBase.cpp b/lgc/builder/BuilderBase.cpp index e0c2d9833c..21d4caf638 100644 --- a/lgc/builder/BuilderBase.cpp +++ b/lgc/builder/BuilderBase.cpp @@ -30,6 +30,7 @@ */ #include "lgc/util/BuilderBase.h" +#include "compilerutils/CompilerUtils.h" #include "lgc/LgcDialect.h" #include "lgc/state/IntrinsDefs.h" #include "llvm/IR/IntrinsicInst.h" @@ -75,44 +76,7 @@ Value *BuilderCommon::CreatePtrDiff(Type *ty, Value *lhs, Value *rhs, const Twin // @param instName : Name to give instruction CallInst *BuilderCommon::CreateNamedCall(StringRef funcName, Type *retTy, ArrayRef args, ArrayRef attribs, const Twine &instName) { - assert(!funcName.empty()); - Module *module = GetInsertBlock()->getParent()->getParent(); - Function *func = dyn_cast_or_null(module->getFunction(funcName)); - if (!func) { - SmallVector argTys; - argTys.reserve(args.size()); - for (auto arg : args) - argTys.push_back(arg->getType()); - - auto funcTy = FunctionType::get(retTy, argTys, false); - func = Function::Create(funcTy, GlobalValue::ExternalLinkage, funcName, module); - - func->setCallingConv(CallingConv::C); - func->addFnAttr(Attribute::NoUnwind); - - for (auto attrib : attribs) { - switch (attrib) { - default: - func->addFnAttr(attrib); - break; - case Attribute::ReadNone: - func->setDoesNotAccessMemory(); - break; - case Attribute::ReadOnly: - func->setOnlyReadsMemory(); - break; - case Attribute::WriteOnly: - func->setOnlyWritesMemory(); - break; - } - } - } - - auto call = CreateCall(func, args, instName); - call->setCallingConv(CallingConv::C); - call->setAttributes(func->getAttributes()); - - return call; + return CompilerUtils::createNamedCall(*this, funcName, retTy, args, attribs, instName); } // ===================================================================================================================== diff --git a/lgc/builder/MatrixBuilder.cpp b/lgc/builder/MatrixBuilder.cpp index dd5d690ada..fd16770333 100644 --- a/lgc/builder/MatrixBuilder.cpp +++ b/lgc/builder/MatrixBuilder.cpp @@ -342,3 +342,315 @@ Value *BuilderImpl::CreateMatrixInverse(Value *const matrix, const Twine &instNa result->setName(instName); return result; } + +// ===================================================================================================================== +// Convert the element type enum into the corresponding LLVM type. +// +// @param elemType : The element type enum value +// @returns the corresponding LLVM type +Type *BuilderCommon::transCooperativeMatrixElementType(CooperativeMatrixElementType elemType) { + switch (elemType) { + case BuilderCommon::CooperativeMatrixElementType::Float16: + return getHalfTy(); + case BuilderCommon::CooperativeMatrixElementType::Float32: + return getFloatTy(); + case BuilderCommon::CooperativeMatrixElementType::Int16: + return getInt16Ty(); + case BuilderCommon::CooperativeMatrixElementType::Int32: + return getInt32Ty(); + case BuilderCommon::CooperativeMatrixElementType::Int8: + return getInt8Ty(); + default: + llvm_unreachable("The element type is not supported."); + } +} + +// ===================================================================================================================== +// Get the LGC type of a cooperative matrix with the given element type and layout. +// +// @param elemType : the matrix element type +// @param layout : the matrix layout +Type *BuilderCommon::getCooperativeMatrixTy(CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout) { + // Note: the layout currently has no influence on the type. In the long run, we should switch to genuinely opaque + // types at the LGC level, and parameterize the type using both the element type and the layout. + + Type *wordTy = transCooperativeMatrixElementType(elemType)->isIntOrIntVectorTy() ? getInt32Ty() : getFloatTy(); + switch (layout) { + case CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout: + case CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout: + case CooperativeMatrixLayout::AccumulatorMatrixLayout: + return FixedVectorType::get(wordTy, 8); + case CooperativeMatrixLayout::FactorMatrixLayout: + if (elemType == CooperativeMatrixElementType::Int8) + return FixedVectorType::get(wordTy, 4); + return FixedVectorType::get(wordTy, 8); + default: + llvm_unreachable("Type is not supported!"); + } +} + +// ===================================================================================================================== +// Determine the "length" of a cooperative matrix for purposes of extract/insert operations. +// +// @param elemType : the matrix element type +// @param layout : the matrix layout +// @param instName : name to give instruction(s) +Value *BuilderCommon::CreateCooperativeMatrixLength(CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + Type *resultTy = getInt32Ty(); + Value *args[] = {getInt32(static_cast(elemType)), getInt32(static_cast(layout))}; + std::string callName(lgcName::CooperativeMatrixLength); + addTypeMangling(resultTy, args, callName); + + Value *result = + CreateNamedCall(callName, resultTy, args, {Attribute::ReadNone, Attribute::Speculatable, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create an "extractelement"-equivalent operation for a cooperative matrix value. +// +// @param matrix : the matrix from which to extract an element +// @param index : the index from which to extract +// @param elemType : the matrix element type +// @param layout : the matrix layout +// @param instName : name to give instruction(s) +Value *BuilderCommon::CreateCooperativeMatrixExtract(Value *matrix, Value *index, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + assert(matrix->getType() == getCooperativeMatrixTy(elemType, layout)); + + Type *resultTy = transCooperativeMatrixElementType(elemType); + Value *args[] = {matrix, index, getInt32(static_cast(elemType)), getInt32(static_cast(layout))}; + std::string callName(lgcName::CooperativeMatrixExtract); + addTypeMangling(resultTy, args, callName); + Value *result = + CreateNamedCall(callName, resultTy, args, {Attribute::ReadNone, Attribute::Speculatable, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create an "insertelement"-equivalent operation for a cooperative matrix value. +// +// @param matrix : the matrix from which to extract an element +// @param index : the index from which to extract +// @param elemType : the matrix element type +// @param layout : the matrix layout +// @param instName : name to give instruction(s) +Value *BuilderCommon::CreateCooperativeMatrixInsert(Value *matrix, Value *value, Value *index, + CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + assert(matrix->getType() == getCooperativeMatrixTy(elemType, layout)); + assert(value->getType() == transCooperativeMatrixElementType(elemType)); + assert(index->getType() == getInt32Ty()); + + Type *resultTy = matrix->getType(); + Value *args[] = {matrix, value, index, getInt32(static_cast(elemType)), + getInt32(static_cast(layout))}; + std::string callName(lgcName::CooperativeMatrixInsert); + addTypeMangling(resultTy, args, callName); + Value *result = + CreateNamedCall(callName, resultTy, args, {Attribute::ReadNone, Attribute::Speculatable, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create cooperative matrix load. +// We only allow the size 16x16 size for a cooperative matrix. So 16 lanes are responsible for reading all data from +// memory. The layout of a cooperative matrix A in the VGPR under wave32 mode is that . Each lane reads a contiguous +// data from memory as a row (or column) of matrix A into the VGPR (implemented as a vector), where A0_0 in one VGPR if +// the data format is f32/i32, A0_0/A0_1 would be in the same VGPR if the data format is f16, A0_0/A0_1/A0_2/A0_3 would +// be in the same VGPR if the data format is i8. +// +// @param pointer : The pointer to a data array. +// @param stride : The number of bytes in memory between the first component of consecutive rows (or columns) in the +// source data. Must be a multiple of the matrix element size. +// @param colMaj : Whether the values loaded from memory are arrayed in column-major or row-major. +// @param elemType : Element type for the matrix. +// @param layout : Identify whether it's A/B or C/D +// @param memoryAccess : Parsed from memory operation. +// @param instName : Name to give instruction(s). +Value *BuilderCommon::CreateCooperativeMatrixLoad(Value *pointer, Value *stride, bool colMajor, + CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + unsigned memoryAccess, const Twine &instName) { + Type *resultTy = getCooperativeMatrixTy(elemType, layout); + std::string callName(lgcName::CooperativeMatrixLoad); + Value *args[] = {pointer, + stride, + getInt1(colMajor), + getInt32(static_cast(elemType)), + getInt32(static_cast(layout)), + getInt32(memoryAccess)}; + addTypeMangling(resultTy, args, callName); + Value *loadVal = CreateNamedCall(callName, resultTy, args, {Attribute::ReadOnly}); + loadVal->setName(instName); + return loadVal; +} + +// ===================================================================================================================== +// Create cooperative matrix store. +// We only allow the size 16x16 size for a cooperative matrix. So 16 lanes are responsible for writing matrix elements +// to memory. The layout of a cooperative matrix A in the VGPR under wave32 mode is that each lane writes a row (or +// column) of matrix A from the VGPRs (implemented as a vector) to the memory, where the value of one VGPR is written +// into a memory location if the data format is f32/i32, the value of one VGPR is split into two values to store if +// the data format is f16, the value of one VGPR is split into four values to store if the data format is i8. +// +// @param pointer : The pointer to a data array. +// @param matrix : The row of cooperative matrix to store. +// @param stride : The number of bytes in memory between the first components of consecutive rows (or columns) in the +// destination. Must be a multiple of the element size. +// @param colMaj : Whether the values loaded from memory are arrayed in column-major or row-major. +// @param elemType : Element type for the matrix. +// @param layout : Identify the matrix type(A/B or C). +// @param memoryAccess : Memoray operands +// @param instName : Name to give instruction(s). +Value *BuilderCommon::CreateCooperativeMatrixStore(Value *pointer, Value *matrix, Value *stride, bool colMajor, + CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, unsigned memoryAccess, + const Twine &instName) { + assert(matrix->getType() == getCooperativeMatrixTy(elemType, layout)); + + std::string callName(lgcName::CooperativeMatrixStore); + Value *args[] = {pointer, + stride, + getInt1(colMajor), + getInt32(static_cast(elemType)), + getInt32(static_cast(layout)), + getInt32(memoryAccess), + matrix}; + addTypeMangling(Type::getVoidTy(getContext()), args, callName); + + Value *storeVal = + CreateNamedCall(callName, Type::getVoidTy(getContext()), args, {Attribute::WriteOnly, Attribute::WillReturn}); + storeVal->setName(instName); + return nullptr; +} + +// ===================================================================================================================== +// Create cooperative matrix conversion. +// Element-wise-conversion +// @param castOp : The cast Opcode. +// @param source : The source cooperative matrix. +// @param srcElemTy : Source matrix's element type. +// @param dstElemTy : Destination matrix's element type. +// @param srcLayout : Layout for source matrix +// @param dstLayout : Layout for target matrix +// @param instName : Name to give instruction(s). +CallInst *BuilderCommon::CreateCooperativeMatrixConvert(CastInst::CastOps castOp, Value *source, + CooperativeMatrixElementType srcElemTy, + CooperativeMatrixElementType dstElemTy, + CooperativeMatrixLayout srcLayout, + CooperativeMatrixLayout dstLayout, const Twine &instName) { + assert(source->getType() == getCooperativeMatrixTy(srcElemTy, srcLayout)); + + Value *args[] = {getInt32(static_cast(castOp)), source, + getInt32(static_cast(srcElemTy)), getInt32(static_cast(dstElemTy)), + getInt32(static_cast(srcLayout)), getInt32(static_cast(dstLayout))}; + Type *resultTy = getCooperativeMatrixTy(dstElemTy, dstLayout); + std::string callName(lgcName::CooperativeMatrixConvert); + addTypeMangling(resultTy, args, callName); + + CallInst *dstElems = CreateNamedCall(callName, resultTy, args, {Attribute::ReadOnly, Attribute::WillReturn}); + dstElems->setName(instName); + return dstElems; +} +// ===================================================================================================================== +// Create cooperative matrix binary operation +// +// @param coopMatArithOp : The cooperative matrix arithmetic operation to perform. +// @param lhs : The first operand and it can be a scalar or a cooperative matrix. +// @param rhs : The second operand and it should be a cooperative matrix. +// @param elemType : Element type for the matrix. +// @param layout : Layout for the matrix. +// @param instName : Name to give instruction(s). +Value *BuilderCommon::CreateCooperativeMatrixBinaryOp(CooperativeMatrixArithOp coopMatArithOp, Value *lhs, Value *rhs, + CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + assert(lhs->getType() == getCooperativeMatrixTy(elemType, layout)); + assert(lhs->getType() == rhs->getType()); + + std::string callName(lgcName::CooperativeMatrixBinOp); + Value *args[] = {getInt32(static_cast(coopMatArithOp)), lhs, rhs, getInt32(static_cast(elemType)), + getInt32(static_cast(layout))}; + addTypeMangling(rhs->getType(), args, callName); + + Value *result = CreateNamedCall(callName, rhs->getType(), args, {Attribute::ReadOnly, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create cooperative matrix MatrixTimesScalar operation +// +// @param matrix : The first operand and it should be a cooperative matrix. +// @param scalar : The second operand and it should be a scalar. +// @param elemType : The component type of the matrix. +// @param layout : Identify whether it's A/B or C/D +// @param instName : Name to give instruction(s). +Value *BuilderCommon::CreateCoopMatrixTimesScalar(Value *matrix, Value *scalar, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + assert(matrix->getType() == getCooperativeMatrixTy(elemType, layout)); + assert(scalar->getType() == transCooperativeMatrixElementType(elemType)); + + std::string callName(lgcName::CooperativeMatrixTimesScalar); + Value *args[] = {matrix, scalar, getInt32(static_cast(elemType)), getInt32(static_cast(layout))}; + addTypeMangling(matrix->getType(), args, callName); + + Value *result = CreateNamedCall(callName, matrix->getType(), args, {Attribute::ReadOnly, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create cooperative matrix transpose operation +// +// @param matrix : The first operand and it should be a cooperative matrix. +// @param elemType : The component type of the matrix. +// @param layout : Identify whether it's A/B or C/D +// @param instName : Name to give instruction(s). +CallInst *BuilderCommon::CreateCooperativeMatrixTranspose(llvm::Value *matrix, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const Twine &instName) { + assert(matrix->getType() == getCooperativeMatrixTy(elemType, layout)); + + std::string callName(lgcName::CooperativeMatrixTranspose); + Value *args[] = {matrix, getInt32(static_cast(elemType)), getInt32(static_cast(layout))}; + addTypeMangling(matrix->getType(), args, callName); + + CallInst *result = CreateNamedCall(callName, matrix->getType(), args, {Attribute::ReadOnly, Attribute::WillReturn}); + result->setName(instName); + return result; +} + +// ===================================================================================================================== +// Create cooperative matrix muladd operation +// +// @param matrixA : Factor cooperative matrix. +// @param matrixB : Factor cooperative matrix. +// @param matrixC : Accumulator cooperative matrix. +// @param isSignedA : Identify the signess for matrix A's element type +// @param isSignedB : Identify the signess for matrix B's element type +// @param isSat : SaturatingAccumulation for calculation +// @param accumElemType : The component type of the accumulator matrix. +// @param factorElemType : The component type of the factor matrix. +Value *BuilderCommon::CreateCooperativeMatrixMulAdd(llvm::Value *matrixA, llvm::Value *matrixB, llvm::Value *matrixC, + bool isSignedA, bool isSignedB, bool isSat, + CooperativeMatrixElementType accumElemType, + CooperativeMatrixElementType factorElemType, + const llvm::Twine &instName) { + std::string callName(lgcName::CooperativeMatrixMulAdd); + Value *args[] = {matrixA, + matrixB, + matrixC, + getInt1(isSignedA), + getInt1(isSignedB), + getInt1(isSat), + getInt32(static_cast(accumElemType)), + getInt32(static_cast(factorElemType))}; + addTypeMangling(matrixC->getType(), args, callName); + + Value *result = CreateNamedCall(callName, matrixC->getType(), args, {Attribute::ReadOnly, Attribute::WillReturn}); + result->setName(instName); + return result; +} diff --git a/lgc/include/lgc/util/CpsStackLowering.h b/lgc/include/lgc/patch/CombineCooperativeMatrix.h similarity index 52% rename from lgc/include/lgc/util/CpsStackLowering.h rename to lgc/include/lgc/patch/CombineCooperativeMatrix.h index e074a7c76c..83dbde3d7e 100644 --- a/lgc/include/lgc/util/CpsStackLowering.h +++ b/lgc/include/lgc/patch/CombineCooperativeMatrix.h @@ -1,7 +1,7 @@ /* *********************************************************************************************************************** * - * Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. + * Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All Rights Reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,51 +24,22 @@ **********************************************************************************************************************/ /** *********************************************************************************************************************** - * @file CpsStackLowering.h - * @brief LLPC header file: contains declaration of class lgc::CpsStackLowering. + * @file CombineCooperativeMatrix.h + * @brief Declares a pass that combines high-level cooperative matrix operations. *********************************************************************************************************************** */ #pragma once - -#include "lgc/LgcCpsDialect.h" -#include "lgc/state/IntrinsDefs.h" -#include "lgc/util/TypeLowering.h" -#include "llvm/ADT/SmallVector.h" +#include "lgc/patch/Patch.h" +#include "llvm/IR/PassManager.h" namespace lgc { -constexpr unsigned continuationStackAlignment = 4; - -inline unsigned getLoweredCpsStackAddrSpace() { - return ADDR_SPACE_PRIVATE; -} -inline unsigned getLoweredCpsStackPointerSize(const llvm::DataLayout &layout) { - return layout.getPointerSize(getLoweredCpsStackAddrSpace()); -} - -class CpsStackLowering { +// ===================================================================================================================== +// Pass to combine cooperative matrix operations. +class CombineCooperativeMatrix : public Patch, public llvm::PassInfoMixin { public: - CpsStackLowering(llvm::LLVMContext &context) : m_typeLowering(context) {} - void lowerCpsStackOps(llvm::Function &function, llvm::Value *); - // Get continuation stack size (in bytes). - unsigned getStackSize() { return m_stackSizeInBytes; } - - TypeLowering m_typeLowering; - -private: - void visitCpsAlloc(cps::AllocOp &alloc); - void visitCpsFree(cps::FreeOp &freeOp); - void visitCpsPeek(cps::PeekOp &peekOp); - void visitSetVsp(cps::SetVspOp &setVsp); - void visitGetVsp(cps::GetVspOp &getVsp); - void visitGetElementPtr(llvm::GetElementPtrInst &getElemPtrInst); - void visitPtrToIntInst(llvm::PtrToIntInst &ptr2Int); - void visitIntToPtrInst(llvm::IntToPtrInst &int2Ptr); - void visitLoad(llvm::LoadInst &load); - void visitStore(llvm::StoreInst &store); + llvm::PreservedAnalyses run(llvm::Function &function, llvm::FunctionAnalysisManager &analysisManager); - llvm::Module *m_module; - llvm::Value *m_cpsStackAlloca; - unsigned m_stackSizeInBytes = 0; + static llvm::StringRef name() { return "lgc-combine-cooperative-matrix"; } }; } // namespace lgc diff --git a/lgc/include/lgc/patch/LowerCooperativeMatrix.h b/lgc/include/lgc/patch/LowerCooperativeMatrix.h new file mode 100644 index 0000000000..f472dd68d7 --- /dev/null +++ b/lgc/include/lgc/patch/LowerCooperativeMatrix.h @@ -0,0 +1,227 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file LowerCooperativeMatrix.h + * @brief LLPC header file : contains declaration of class lgc::LowerCooperativeMatrix.h + *********************************************************************************************************************** + */ +#pragma once +#include "SystemValues.h" +#include "lgc/Builder.h" +#include "lgc/patch/Patch.h" +#include "lgc/state/PipelineShaders.h" +#include "lgc/state/PipelineState.h" +#include "lgc/state/TargetInfo.h" +#include "llvm/IR/Function.h" + +namespace lgc { +// ===================================================================================================================== +// Pass to lower coopMatrix calls +class LowerCooperativeMatrix : public Patch, public llvm::PassInfoMixin { +public: + llvm::PreservedAnalyses run(llvm::Module &module, llvm::ModuleAnalysisManager &analysisManager); + + bool runImpl(llvm::Module &module, PipelineShadersResult &pipelineShaders, PipelineState *pipelineState); + + static llvm::StringRef name() { return "Patch cooperative matrix calls"; } + + void visitCallInst(llvm::CallInst &callInst); + +private: + void processCoopMatrixFunction(llvm::ArrayRef coopMatrixCallees); + + struct TypeProperties { + // Number of (true) elements per lane. + unsigned numFlatElements; + + // Number of (true and unused) elements per lane when casting an LGC dialect cooperative matrix type to + // . + unsigned numMatrixElements; + + // Number of dwords per lane in an LGC dialect cooperative matrix type. + unsigned numMatrixWords; + + // Stride of elements. + unsigned matrixElementStride; + }; + + struct ComputeAddressInfo { + // The base address for the first element in each lane. + llvm::Value *base; + + // The increasing step between the last element in preVgpr and first element in curVgpr. + llvm::Value *macroStep; + + // It will only be set on 16bit@Accumulator@gfx10 like:{C0_0,C1_0;C4_0,C5_0} + llvm::Value *microStep; + + // It will only be set on 16bit @Accumulator @gfx10 like : {C0_0, C1_0; C4_0, C5_0}, which value will + // be 2. + unsigned microCount; + }; + + unsigned getLength(Builder::CooperativeMatrixLayout layout) const; + + TypeProperties getTypeProperties(Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) const; + + ComputeAddressInfo computeAddressing(Builder::CooperativeMatrixLayout layout, + Builder::CooperativeMatrixElementType elemType, int waveSize, + llvm::Value *stride, bool isColMajor, llvm::Instruction *insertPos); + + llvm::Value *cooperativeMatrixLoadInternal(llvm::Value *dataPtr, llvm::Value *stride, bool colMajor, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, unsigned memoryAccess, + const llvm::Twine &instName, llvm::Instruction *insertPos); + // Convert vector data to cooperativeMatrix vec data + // eg. v16*data_In_Buffer-->v8*coopMatrix_data as two 16bits elements packed. + llvm::Value *convFlatVecToCoopMatrixVec(BuilderCommon &builder, llvm::Value *vecValue, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout); + + // Convert cooperativeMatrix vec data to vec data. + llvm::Value *convCoopMatrixVecToFlatVec(BuilderCommon &builder, llvm::Value *matrixValue, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout); + + // Create cooperative matrix store operation + void cooperativeMatrixStoreInternal(llvm::Value *dataPtr, llvm::Value *stride, bool colMajor, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, unsigned memoryAccess, + llvm::Value *&vecVal, const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Open-code cooperative matrix extract operation + llvm::Value *cooperativeMatrixExtract(BuilderCommon &builder, llvm::Value *matrix, llvm::Value *index, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout); + + // Open-code cooperative matrix insert operation + llvm::Value *cooperativeMatrixInsert(BuilderCommon &builder, llvm::Value *matrix, llvm::Value *value, + llvm::Value *index, Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout); + + // Create cooperative matrix convert operation + llvm::Value *cooperativeMatrixConvert(llvm::CastInst::CastOps castOp, llvm::Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srclayout, + Builder::CooperativeMatrixLayout dstlayout, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + // Create cooperative matrix convert operation without reshape operation + llvm::Value *cooperativeMatrixConvertInternal(llvm::CastInst::CastOps castOp, llvm::Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Create cooperative matrix binary operation + llvm::Value *cooperativeMatrixBinaryOp(Builder::CooperativeMatrixArithOp coopMatArithOp, llvm::Value *lhs, + llvm::Value *rhs, Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + // Create cooperative matrixTimeScalar operation + llvm::Value *coopMatrixTimesScalar(llvm::Value *matrix, llvm::Value *scalar, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + // Create cooperative matrix reshape operation for 16bit on gfx10 and gfx110 + llvm::Value *cooperativeMatrixReshape16BitElementGfx1011(llvm::Value *matrix, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + llvm::Value *threadId, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + // Create cooperative matrix reshape operation for 8bit on gfx10 and gfx11 + llvm::Value *cooperativeMatrixReshapeBetween8bitAnd32bitElementGfx1011( + llvm::Value *matrix, Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixLayout srcLayout, const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Adjust the layout on accumulator for gfx10 + llvm::Value *cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10( + llvm::Value *source, Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, Builder::CooperativeMatrixLayout layout, + llvm::Value *isEvenGroup, const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Adjust the layout before reshape operation(eg:float16->float32) + llvm::Value *cooperativeMatrixReshapeBeforeConvert(llvm::Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Adjust the layout before reshape operation(eg:float32->float16) + llvm::Value *cooperativeMatrixReshapeAfterConvert(llvm::Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Create cooperative matrix transpose operation + llvm::Value *cooperativeMatrixTranspose(llvm::Value *matrix, Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout srcLayout, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + llvm::Value *transposeCooperativeMatrixRecursively(llvm::Value *matrix, unsigned vecStride, unsigned laneStride, + llvm::Value *threadId, BuilderBase &builder); + + // Create cooperative matrix muladd operation + llvm::Value *cooperativeMatrixMulAdd(llvm::Value *copMatrixa, llvm::Value *copMatrixb, llvm::Value *copMatrixc, + bool isSignedA, bool isSignedB, bool isSat, + Builder::CooperativeMatrixElementType accumElemType, + Builder::CooperativeMatrixElementType factorElemType, + const llvm::Twine &instName, llvm::Instruction *insertPos); + + // Simulating for WMMA + llvm::Value *createDotProductFp16Fp16(llvm::Value *const vector1, llvm::Value *const vector2, + llvm::Value *const accumulator, bool isSat, const llvm::Twine &instName, + llvm::Instruction *insertPos); + llvm::Value *createDotProductFp16Fp32(llvm::Value *const vector1, llvm::Value *const vector2, + llvm::Value *const accumulator, bool isSat, const llvm::Twine &instName, + llvm::Instruction *insertPos); + llvm::Value *createDotProductInt16Int32(llvm::Value *vector1, llvm::Value *vector2, llvm::Value *accumulator, + unsigned flags, bool isSat, const llvm::Twine &instName, + llvm::Instruction *insertPos); + llvm::Value *createDotProductInt8Int32(llvm::Value *vector1, llvm::Value *vector2, llvm::Value *accumulator, + unsigned flags, bool isSat, const llvm::Twine &instName, + llvm::Instruction *insertPos); + llvm::Value *createDotProductInt16Int16(llvm::Value *vector1, llvm::Value *vector2, llvm::Value *accumulator, + unsigned flags, bool isSat, const llvm::Twine &instName, + llvm::Instruction *insertPos); + + llvm::Value *getLaneNumber(BuilderBase &builder); + + llvm::SmallVector m_coopMatrixCalls; + PipelineState *m_pipelineState = nullptr; + PipelineShadersResult *m_pipelineShaders = nullptr; + GfxIpVersion m_gfxIp; +}; + +} // namespace lgc diff --git a/lgc/include/lgc/patch/PatchBufferOp.h b/lgc/include/lgc/patch/PatchBufferOp.h index 29a7f65ba7..3f9345425c 100644 --- a/lgc/include/lgc/patch/PatchBufferOp.h +++ b/lgc/include/lgc/patch/PatchBufferOp.h @@ -30,6 +30,7 @@ */ #pragma once +#include "compilerutils/TypeLowering.h" #include "lgc/patch/Patch.h" #include "llvm-dialects/Dialect/Visitor.h" #include "llvm/ADT/DenseMap.h" @@ -53,7 +54,6 @@ class BufferDescToPtrOp; class BufferLengthOp; class BufferPtrDiffOp; class PipelineState; -class TypeLowering; // ===================================================================================================================== // Helper class for lowering buffer operations integrated with a flow based on llvm_dialects::Visitor and TypeLowering. diff --git a/lgc/include/lgc/patch/PatchEntryPointMutate.h b/lgc/include/lgc/patch/PatchEntryPointMutate.h index 40fd5ae48b..f7674d0c1b 100644 --- a/lgc/include/lgc/patch/PatchEntryPointMutate.h +++ b/lgc/include/lgc/patch/PatchEntryPointMutate.h @@ -30,15 +30,17 @@ */ #pragma once +#include "compilerutils/TypeLowering.h" +#include "continuations/CpsStackLowering.h" #include "lgc/LgcCpsDialect.h" #include "lgc/LgcDialect.h" #include "lgc/patch/Patch.h" #include "lgc/patch/ShaderInputs.h" #include "lgc/state/PipelineShaders.h" #include "lgc/state/PipelineState.h" -#include "lgc/util/TypeLowering.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" +#include namespace lgc { @@ -200,6 +202,7 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin m_funcCpsStackMap; llvm::Intrinsic::ID m_setInactiveChainArgId; + std::unique_ptr stackLowering; }; } // namespace lgc diff --git a/lgc/include/lgc/state/Defs.h b/lgc/include/lgc/state/Defs.h index e92ca561e9..630fdb1f7c 100644 --- a/lgc/include/lgc/state/Defs.h +++ b/lgc/include/lgc/state/Defs.h @@ -73,6 +73,18 @@ const static char CopyShaderEntryPoint[] = "lgc.shader.COPY.main"; const static char NullFsEntryPoint[] = "lgc.shader.FS.null.main"; const static char TcsPassthroughEntryPoint[] = "lgc.shader.TCS.passthrough.main"; +const static char CooperativeMatrix[] = "lgc.cooperative.matrix"; +const static char CooperativeMatrixLength[] = "lgc.cooperative.matrix.length"; +const static char CooperativeMatrixExtract[] = "lgc.cooperative.matrix.extract"; +const static char CooperativeMatrixInsert[] = "lgc.cooperative.matrix.insert"; +const static char CooperativeMatrixLoad[] = "lgc.cooperative.matrix.load"; +const static char CooperativeMatrixStore[] = "lgc.cooperative.matrix.store"; +const static char CooperativeMatrixConvert[] = "lgc.cooperative.matrix.convert"; +const static char CooperativeMatrixBinOp[] = "lgc.cooperative.matrix.binop"; +const static char CooperativeMatrixTimesScalar[] = "lgc.cooperative.matrix.times.scalar"; +const static char CooperativeMatrixTranspose[] = "lgc.cooperative.matrix.transpose"; +const static char CooperativeMatrixMulAdd[] = "lgc.cooperative.matrix.muladd"; + } // namespace lgcName // Value for high half of address that means "use PC". diff --git a/lgc/include/lgc/util/Internal.h b/lgc/include/lgc/util/Internal.h index 5726619f56..caad44c1a8 100644 --- a/lgc/include/lgc/util/Internal.h +++ b/lgc/include/lgc/util/Internal.h @@ -96,9 +96,4 @@ bool isDontCareValue(llvm::Value *value); // type in a return value struct, ensuring it gets into VGPRs. llvm::Type *getVgprTy(llvm::Type *ty); -// Modify the function argument types, and return the new function. NOTE: the function does not do any uses -// replacement, so the caller should call replaceAllUsesWith() for the function and arguments afterwards. -llvm::Function *mutateFunctionArguments(llvm::Function &fn, llvm::Type *retTy, - const llvm::ArrayRef argTys, llvm::AttributeList attributes); - } // namespace lgc diff --git a/lgc/interface/lgc/BuilderCommon.h b/lgc/interface/lgc/BuilderCommon.h index bef986246c..f006768365 100644 --- a/lgc/interface/lgc/BuilderCommon.h +++ b/lgc/interface/lgc/BuilderCommon.h @@ -83,6 +83,157 @@ class BuilderCommon : public llvm_dialects::Builder { // @param instName : Name to give instruction llvm::CallInst *CreateNamedCall(llvm::StringRef funcName, llvm::Type *retTy, llvm::ArrayRef args, llvm::ArrayRef attribs, const llvm::Twine &instName = ""); + + // ----------------------------------------------------------------------------------------------------------------- + // Cooperative matrix operation. + + enum CooperativeMatrixMemoryAccess { + MemoryAccessMaskNone = 0x00, // No mask + MemoryAccessVolatileMask = 0x01, // Access memory in volatile + MemoryAccessCoherentMask = 0x02, // Access memory in coherent + MemoryAccessTemporalMask = 0x04, // Access memory in temporal + }; + + enum CooperativeMatrixElementType { + Unknown = 0, // Unknown + Float16, // 16-bit floating-point + Float32, // 32-bit floating-point + Int8, // 8-bit integer + Int16, // 16-bit integer + Int32 // 32 bit integer + }; + + // Layout is virtual concept, eg: 16bit and 32bit for matrixC will share the same layout initially. + // It will be passed as the argument of getTypeProperties to calculate the more detailed layout information. + enum CooperativeMatrixLayout { + FactorMatrixLayout = 0, // A/B layout on gfx10/gfx11 + AccumulatorMatrixLayout, // C/D layout on gfx11 + Gfx10AccumulatorMatrixLayout, // 32bit@C/D layout on gfx10 + Gfx10Accumulator16bitMatrixLayout, // 16bit@C/D layout on gfx10 + InvalidLayout + }; + + // The cooperative matrix arithmetic operations the builder can consume. + // NOTE: We rely on casting this implicitly to an integer, so we cannot use an enum class. + enum class CooperativeMatrixArithOp { + IAdd = 0, + FAdd, + ISub, + FSub, + IMul, + FMul, + UDiv, + SDiv, + FDiv, + UMod, + SRem, + SMod, + FRem, + FMod + }; + + // Convert the element type enum into the corresponding LLVM type. + llvm::Type *transCooperativeMatrixElementType(CooperativeMatrixElementType elemType); + + // Get the LGC type of a cooperative matrix with the given element type and layout. + llvm::Type *getCooperativeMatrixTy(CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout); + + // Determine the "length" of a cooperative matrix for purposes of extract/insert operations. + llvm::Value *CreateCooperativeMatrixLength(CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + const llvm::Twine &instName = ""); + + // Create an "extractelement"-equivalent operation for a cooperative matrix value. + llvm::Value *CreateCooperativeMatrixExtract(llvm::Value *matrix, llvm::Value *index, + CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + const llvm::Twine &instName = ""); + + // Create an "insertelement"-equivalent operation for a cooperative matrix value. + llvm::Value *CreateCooperativeMatrixInsert(llvm::Value *matrix, llvm::Value *value, llvm::Value *index, + CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + const llvm::Twine &instName = ""); + + // Create cooperative matrix load. + // + // @param pointer : The pointer to a data array. + // @param stride : The number of elements in the array in memory between the first component of consecutive rows (or + // columns) in the result. + // @param colMaj : Whether the values loaded from memory are arrayed in column-major or row-major. + // @param layout : Identify it's factor or accumulator + // @param memoryAccess : Parsed from Memory operands in SPIRV-reader + // @param instName : Name to give instruction(s) + llvm::Value *CreateCooperativeMatrixLoad(llvm::Value *pointer, llvm::Value *stride, bool colMajor, + CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + unsigned memoryAccess, const llvm::Twine &instName = ""); + + // Create cooperative matrix store. + // + // @param pointer : The pointer to a data array. + // @param matrix : The cooperative matrix to store. + // @param stride : The number of elements in the array in memory between the first component of consecutive rows (or + // columns) in the result. + // @param colMaj : Whether the values loaded from memory are arrayed in column-major or row-major. + // @param layout : Identify it's factor or accumulator + // @param memoryAccess : Parsed from Memory operands in SPIRV-reader + // @param instName : Name to give instruction(s). + llvm::Value *CreateCooperativeMatrixStore(llvm::Value *pointer, llvm::Value *matrix, llvm::Value *stride, + bool colMajor, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, unsigned memoryAccess, + const llvm::Twine &instName = ""); + + // Create cooperative matrix conversion. + // @param opCode : The convert opCode. + // @param source : The source cooperative matrix. + // @param dest : The conversion target. + // @param instName : Name to give instruction(s). + llvm::CallInst *CreateCooperativeMatrixConvert(llvm::CastInst::CastOps opCode, llvm::Value *source, + CooperativeMatrixElementType srcElemType, + CooperativeMatrixElementType dstElemType, + CooperativeMatrixLayout srcLayout, CooperativeMatrixLayout dstLayout, + const llvm::Twine &instName = ""); + + // Create cooperative matrix binary operation + // + // @param coopMatArithOp : The cooperative matrix arithmetic operation to perform. + // @param operand1 : The first operand. + // @param operand2 : The second operand. + // @param instName : Name to give instruction(s). + llvm::Value *CreateCooperativeMatrixBinaryOp(CooperativeMatrixArithOp coopMatArithOp, llvm::Value *lhs, + llvm::Value *rhs, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout layout, const llvm::Twine &instName = ""); + + // Create cooperative MatrixTimesScalar binary operation + // + // @param matrix : It should be cooperative matrix. + // @param scalar : It should be scalar type. + // @param elemType : Name to give instruction(s). + // @param layout : Identify A/B matrices or C/D matrices. + llvm::Value *CreateCoopMatrixTimesScalar(llvm::Value *matrix, llvm::Value *scalar, + CooperativeMatrixElementType elemType, CooperativeMatrixLayout layout, + const llvm::Twine &instName = ""); + + // ===================================================================================================================== + // Create cooperative matrix transpose operation + // + // @param matrix : The first operand and it should be a cooperative matrix. + // @param elemType : The component type of the matrix. + // @param srcLayout : Identify whether it's A/B or C/D + llvm::CallInst *CreateCooperativeMatrixTranspose(llvm::Value *matrix, CooperativeMatrixElementType elemType, + CooperativeMatrixLayout srcLayout, const llvm::Twine &instName = ""); + + // Create cooperative matrix muladd operation + // @param coopMatrixa : Factor cooperative matrix. + // @param coopMatrixb : Factor cooperative matrix. + // @param coopMatrixc : Accumulator cooperative matrix. + // @param isSignedA : Identify the signess for matrix A's element type + // @param isSignedB : Identify the signess for matrix B's element type + // @param isSat : SaturatingAccumulation for calculation + // @param accumElemType : The component type of the matrix c + // @param factorElemType : The component type of the matrix a + llvm::Value *CreateCooperativeMatrixMulAdd(llvm::Value *coopMatrixa, llvm::Value *coopMatrixb, + llvm::Value *coopMatrixc, bool isSignedA, bool isSignedB, bool isSat, + CooperativeMatrixElementType accumElemType, + CooperativeMatrixElementType factorElemType, + const llvm::Twine &instName = ""); }; } // namespace lgc diff --git a/lgc/interface/lgc/Pipeline.h b/lgc/interface/lgc/Pipeline.h index 9494e47098..75f4335340 100644 --- a/lgc/interface/lgc/Pipeline.h +++ b/lgc/interface/lgc/Pipeline.h @@ -111,6 +111,12 @@ enum class RayTracingIndirectMode : unsigned { Continuations = 3, // Continuations flow that based on LowerRaytracingPipeline pass }; +// Enumerate feature flags for CPS. +enum CpsFlag : unsigned { + CpsNoFlag = 0, + CpsFlagStackInGlobalMem = 1 << 0, // Put stack in global memory instead of scratch. +}; + // Value for shadowDescriptorTable pipeline option. static const unsigned ShadowDescriptorTableDisable = ~0U; @@ -174,6 +180,8 @@ union Options { unsigned reserved20; RayTracingIndirectMode rtIndirectMode; // Ray tracing indirect mode bool enablePrimGeneratedQuery; // Whether to enable primitive generated counter + bool enableFragColor; // If enabled, do frag color broadcast + unsigned cpsFlags; // CPS feature flags }; }; static_assert(sizeof(Options) == sizeof(Options::u32All)); @@ -388,7 +396,8 @@ enum BufDataFormat { BufDataFormat5_6_5_1_Bgra, BufDataFormat1_5_6_5, BufDataFormat5_9_9_9, - BufDataFormat8_A + BufDataFormat8_A, + BufDataFormat16_16_16 }; // Numeric format of vertex buffer entry. These match the GFX9 hardware encoding. @@ -407,14 +416,10 @@ enum BufNumFormat { BufNumFormatOther, }; -// Rate of vertex input. This encodes both the "rate" (none/vertex/instance), and, for "instance", -// the divisor that determines how many instances share the same vertex buffer element. +// Rate of vertex input enum VertexInputRate { - VertexInputRateVertex = ~0, // Vertex buffer has one element per vertex - VertexInputRateNone = 0, // Vertex buffer has one element shared between all instances + VertexInputRateVertex = 0, // Vertex buffer has one element per vertex VertexInputRateInstance = 1, // Vertex buffer has one element per instance - // Other value N means vertex buffer has one element per N instances; - // N is the divisor. }; // Structure for a vertex input @@ -429,6 +434,7 @@ struct VertexInputDescription { BufDataFormat dfmt; // Data format of input; one of the BufDataFormat* values BufNumFormat nfmt; // Numeric format of input; one of the BufNumFormat* values unsigned inputRate; // Vertex input rate for the binding + unsigned divisor; // Instance divisor }; // Represents assistant info for each vertex attribute in uber fetch shader diff --git a/lgc/patch/CombineCooperativeMatrix.cpp b/lgc/patch/CombineCooperativeMatrix.cpp new file mode 100644 index 0000000000..fbf3588677 --- /dev/null +++ b/lgc/patch/CombineCooperativeMatrix.cpp @@ -0,0 +1,616 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file CombineCooperativeMatrix.cpp + * @brief Pass and helpers for combining cooperative matrix operations. + * + * This pass is the place for combining / optimizing high-level cooperative matrix ops (@lgc.cooperative.matrix.*). + * + * In particular, this pass reduces the number of transpose and convert operations. + *********************************************************************************************************************** + */ +#include "lgc/patch/CombineCooperativeMatrix.h" +#include "lgc/Builder.h" +#include "lgc/state/Defs.h" +#include + +#define DEBUG_TYPE "lgc-combine-cooperative-matrix" + +using namespace llvm; +using namespace lgc; + +namespace { + +struct Shape { + Builder::CooperativeMatrixElementType elementType; + Builder::CooperativeMatrixLayout layout; + + Shape(Builder::CooperativeMatrixElementType elementType_, Builder::CooperativeMatrixLayout layout_) + : elementType(elementType_), layout(layout_) {} + + bool operator==(const Shape &rhs) const { return elementType == rhs.elementType && layout == rhs.layout; } +}; + +// A component of the data flow graph that starts at inputs (definitions by operations and function arguments) +// and ends at uses of the value. There are no operations inside the component, but there can be arbitrarily complex +// networks of phi nodes. +struct DataFlowComponent { + SmallVector inputs; + SmallVector phis; + SmallVector outputs; + std::optional shape; +}; + +class CooperativeMatrixCombiner { +public: + CooperativeMatrixCombiner(Function &function) : m_function(function), b(function.getContext()) {} + + bool run(); + +private: + Shape getShapeOfTranspose(CallInst *transpose); + void foldTo(Value *from, Value *to); + bool tryFold(CallInst *op); + bool tryFoldComponentContaining(Value *start); + + Function &m_function; + BuilderCommon b; + std::vector m_eraseList; +}; + +} // anonymous namespace + +// ===================================================================================================================== +// Run the combiner. +// +// @returns : True if the function was modified by the transformation and false otherwise +bool CooperativeMatrixCombiner::run() { + LLVM_DEBUG(dbgs() << "Running the cooperative matrix combiner on " << m_function.getName() << '\n'); + + bool changed = false; + + // Step 1: Collect transposes and converts + std::vector ops; + + for (Function &fn : m_function.getParent()->functions()) { + if (!fn.isDeclaration()) + continue; + + if (fn.getName().startswith(lgcName::CooperativeMatrixTranspose)) { + for (User *user : fn.users()) { + if (auto *call = dyn_cast(user)) { + if (call->getFunction() == &m_function) + ops.push_back(call); + } + } + } else if (fn.getName().startswith(lgcName::CooperativeMatrixConvert)) { + for (User *user : fn.users()) { + if (auto *call = dyn_cast(user)) { + if (call->getFunction() == &m_function) + ops.push_back(call); + } + } + } + } + + // Step 2: Attempt folds. + for (const WeakVH &handle : ops) { + auto *op = cast_or_null(handle); + if (!op) + continue; + + if (tryFold(op)) { + changed = true; + + for (Instruction *inst : llvm::reverse(m_eraseList)) { + if (inst->use_empty()) + inst->eraseFromParent(); + } + m_eraseList.clear(); + } + } + ops.clear(); + + return changed; +} + +// ===================================================================================================================== +// Determine the shape of the given transpose operation. +// +// @param [in] transpose : the transpose operation +// @returns : the cooperative matrix shape +Shape CooperativeMatrixCombiner::getShapeOfTranspose(CallInst *transpose) { + unsigned elemType = cast(transpose->getArgOperand(1))->getZExtValue(); + unsigned layout = cast(transpose->getArgOperand(2))->getZExtValue(); + return {(Builder::CooperativeMatrixElementType)elemType, (Builder::CooperativeMatrixLayout)layout}; +} + +// ===================================================================================================================== +// Replace all uses of @p from with @p to. +// +// This method queues @p from for possible deletion, but will _not_ delete it immediately. Deletion is deferred to the +// main combiner loop. +// +// Note: This is a separate method since we may eventually add related operations back to a worklist for iterative +// folding, but this is currently not implemented. +// +// @param [in] from : the value to be replaced +// @param [out] to : the replacement value +void CooperativeMatrixCombiner::foldTo(Value *from, Value *to) { + from->replaceAllUsesWith(to); + + if (auto *fromInst = dyn_cast(from)) + m_eraseList.push_back(fromInst); +} + +// ===================================================================================================================== +// Try to fold / combine around a given transpose or convert operation. +// +// @param [in] op : the operation to try to fold +// @returns : whether a change was made +bool CooperativeMatrixCombiner::tryFold(CallInst *op) { + Value *src; + bool isConvert = false; + if (op->getCalledFunction()->getName().startswith(lgcName::CooperativeMatrixConvert)) { + src = op->getArgOperand(1); + isConvert = true; + } else { + assert(op->getCalledFunction()->getName().startswith(lgcName::CooperativeMatrixTranspose)); + src = op->getArgOperand(0); + } + + if (auto *constant = dyn_cast(src)) { + if (isa(constant)) { + // tranpose/convert(poison) -> poison + foldTo(op, PoisonValue::get(op->getType())); + return true; + } + if (isa(constant)) { + // transpose/convert(undef) -> undef, if legal + bool isFoldable = true; + if (isConvert) { + auto srcElementType = + (Builder::CooperativeMatrixElementType)cast(op->getArgOperand(2))->getZExtValue(); + auto dstElementType = + (Builder::CooperativeMatrixElementType)cast(op->getArgOperand(3))->getZExtValue(); + if (srcElementType != dstElementType) { + // This is slightly conservative, but the point here is that e.g. `zext undef(i16) to i32` can't be folded + // to undef because the result can't truly take all possible bit patterns. + isFoldable = false; + } + } + + if (isFoldable) { + foldTo(op, UndefValue::get(op->getType())); + return true; + } + } + if (constant->isNullValue()) { + // transpose/convert(zeroinitializer) -> zeroinitializer + foldTo(op, Constant::getNullValue(op->getType())); + return true; + } + } else if (auto *inst = dyn_cast(src)) { + if (tryFoldComponentContaining(inst)) + return true; + } + + if (tryFoldComponentContaining(op)) + return true; + + return false; +} + +// ===================================================================================================================== +// Discover the data flow component involving @p start and try to fold it. +// +// @param [in] start : the starting value for component discovery +// @returns : whether a change was made +bool CooperativeMatrixCombiner::tryFoldComponentContaining(Value *start) { + LLVM_DEBUG(dbgs() << "tryFoldComponentContaining: " << *start << '\n'); + + assert(!isa(start)); + + // Step 1: Discover the component + DataFlowComponent component; + SmallVector worklist; + + if (auto *phi = dyn_cast(start)) + component.phis.push_back(phi); + else + component.inputs.push_back(start); + worklist.push_back(start); + + do { + Value *current = worklist.pop_back_val(); + + auto foundPhi = [&](PHINode *phi) { + if (llvm::any_of(component.phis, [=](auto elem) { return elem == phi; })) + return; + component.phis.push_back(phi); + worklist.push_back(phi); + }; + + for (Use &use : current->uses()) { + if (auto *phi = dyn_cast(use.getUser())) { + foundPhi(phi); + continue; + } + + component.outputs.push_back(&use); + } + + if (auto *phi = dyn_cast(current)) { + for (Value *incoming : phi->incoming_values()) { + if (auto *parentPhi = dyn_cast(incoming)) { + foundPhi(parentPhi); + } else { + if (llvm::any_of(component.inputs, [=](auto elem) { return elem == incoming; })) + continue; + if (!isa(incoming)) { + component.inputs.push_back(incoming); + worklist.push_back(incoming); + } + } + } + } + } while (!worklist.empty()); + + // Step 2: Analyze the inputs and outputs. + std::optional otherLayout; + Type *otherType = nullptr; + unsigned numUnhandledInputs = 0; + unsigned numTransposeInputs = 0; + unsigned numRelayoutInputs = 0; + DenseSet unhandledOutputs; + DenseSet transposeOutputs; + DenseSet relayoutOutputs; + + auto foundComponentShape = [&](Shape shape) { + if (!component.shape) + component.shape = shape; + else + assert(*component.shape == shape); + }; + + auto foundOtherLayout = [&](Builder::CooperativeMatrixLayout layout, Type *type) { + if (!otherLayout) { + otherLayout = layout; + otherType = type; + } else { + assert(*otherLayout == layout); + assert(otherType == type); + } + }; + + for (Value *input : component.inputs) { + if (auto *constant = dyn_cast(input)) { + if (!constant->isNullValue() && !isa(constant) && !isa(constant)) { + // We could try to rewrite other constants, or insert transpose/convert operations as required, but we're + // quite unlikely to encounter this in the first place, so let's not bother with the complexity. + LLVM_DEBUG(dbgs() << " bail out due to unhandled constant: " << *input << '\n'); + return false; + } + + continue; + } + + if (auto *call = dyn_cast(input)) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixLoad)) + continue; // loads can be adjusted at zero cost + if (callee->getName().startswith(lgcName::CooperativeMatrixTranspose)) { + foundComponentShape(getShapeOfTranspose(call)); + ++numTransposeInputs; + continue; + } + if (callee->getName().startswith(lgcName::CooperativeMatrixConvert)) { + auto srcElemType = + (Builder::CooperativeMatrixElementType)cast(call->getArgOperand(2))->getZExtValue(); + auto dstElemType = + (Builder::CooperativeMatrixElementType)cast(call->getArgOperand(3))->getZExtValue(); + if (srcElemType != dstElemType) { + LLVM_DEBUG(dbgs() << " unhandled element type input conversion: " << *call << '\n'); + ++numUnhandledInputs; + continue; + } + + auto srcLayout = (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(4))->getZExtValue(); + auto dstLayout = (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(5))->getZExtValue(); + foundComponentShape({dstElemType, dstLayout}); + foundOtherLayout(srcLayout, call->getArgOperand(1)->getType()); + + ++numRelayoutInputs; + continue; + } + } + ++numUnhandledInputs; + continue; + } + + ++numUnhandledInputs; + } + + for (Use *use : component.outputs) { + if (auto *call = dyn_cast(use->getUser())) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixStore)) + continue; // stores can be adapted at zero cost + if (callee->getName().startswith(lgcName::CooperativeMatrixTranspose)) { + foundComponentShape(getShapeOfTranspose(call)); + transposeOutputs.insert(use->get()); + continue; + } + if (callee->getName().startswith(lgcName::CooperativeMatrixConvert)) { + auto srcElemType = + (Builder::CooperativeMatrixElementType)cast(call->getArgOperand(2))->getZExtValue(); + auto dstElemType = + (Builder::CooperativeMatrixElementType)cast(call->getArgOperand(3))->getZExtValue(); + if (srcElemType != dstElemType) { + LLVM_DEBUG(dbgs() << " unhandled element type output conversion: " << *call << '\n'); + ++numUnhandledInputs; + continue; + } + + auto srcLayout = (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(4))->getZExtValue(); + auto dstLayout = (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(5))->getZExtValue(); + foundComponentShape({srcElemType, srcLayout}); + foundOtherLayout(dstLayout, call->getType()); + + relayoutOutputs.insert(use->get()); + continue; + } + } + } + + unhandledOutputs.insert(use->get()); + } + + // Step 3: Transpose the component if that is beneficial. + int transposeCost = -(numTransposeInputs + transposeOutputs.size()); + transposeCost += numUnhandledInputs + numRelayoutInputs + unhandledOutputs.size() + relayoutOutputs.size(); + + LLVM_DEBUG(dbgs() << " transpose cost delta: " << transposeCost << '\n'); + + if (transposeCost < 0) { + // Cache for newly inserted transpose operations. + DenseMap outTransposed; + + for (Value *input : component.inputs) { + // Handle inputs that can be folded away / absorbed. + if (auto *call = dyn_cast(input)) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixTranspose)) { + Value *src = call->getArgOperand(0); + foldTo(input, src); + + // Prepopulate the transpose cache to re-use the old transpose operation instead of creating a new one. + outTransposed.try_emplace(src, input); + continue; + } + if (callee->getName().startswith(lgcName::CooperativeMatrixLoad)) { + bool colMajor = cast(call->getArgOperand(2))->getZExtValue(); + call->setArgOperand(2, b.getInt1(!colMajor)); + continue; + } + } + } + + // Handle generic inputs that need to be transposed explicitly. + if (auto *inst = dyn_cast(input)) { + b.SetInsertPoint(inst->getNextNode()); + } else { + assert(isa(input)); + b.SetInsertPointPastAllocas(&m_function); + } + + auto *transposed = b.CreateCooperativeMatrixTranspose(PoisonValue::get(input->getType()), + component.shape->elementType, component.shape->layout); + foldTo(input, transposed); + transposed->setArgOperand(0, input); + } + + for (Use *use : component.outputs) { + // Handle outputs that can be folded away / absorbed. + if (auto *call = dyn_cast(use->getUser())) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixTranspose)) { + foldTo(call, use->get()); + continue; + } + if (callee->getName().startswith(lgcName::CooperativeMatrixStore)) { + bool colMajor = cast(call->getArgOperand(2))->getZExtValue(); + call->setArgOperand(2, b.getInt1(!colMajor)); + continue; + } + } + } + + // Handle generic outputs that need to be transposed explicitly. + Value *&transposed = outTransposed[use->get()]; + if (!transposed) { + if (auto *phi = cast(use->get())) { + b.SetInsertPoint(phi->getParent(), phi->getParent()->getFirstInsertionPt()); + } else { + auto *def = cast(use->get()); + b.SetInsertPoint(def->getNextNode()); + } + + transposed = + b.CreateCooperativeMatrixTranspose(use->get(), component.shape->elementType, component.shape->layout); + } + + use->set(transposed); + } + + return true; + } + + // Step 4: Otherwise, relayout the component if that is beneficial. + int relayoutCost = -(numRelayoutInputs + relayoutOutputs.size()); + relayoutCost += numUnhandledInputs + numTransposeInputs + unhandledOutputs.size() + transposeOutputs.size(); + + LLVM_DEBUG(dbgs() << " relayout cost delta: " << relayoutCost << '\n'); + + if (relayoutCost < 0) { + // Cache for newly inserted relayout convert operations. + DenseMap outRelayouted; + + // Force-override phi types if necessary + if (!component.phis.empty() && component.phis[0]->getType() != otherType) { + for (PHINode *phi : component.phis) { + phi->mutateType(otherType); + + for (Use &use : phi->incoming_values()) { + if (auto *constant = dyn_cast(use.get())) { + if (constant->isNullValue()) { + use.set(Constant::getNullValue(otherType)); + } else if (isa(constant)) { + use.set(UndefValue::get(otherType)); + } else if (isa(constant)) { + use.set(PoisonValue::get(otherType)); + } else { + // We should have bailed out earlier in this case. + llvm_unreachable("unhandled constant in cooperative matrix phi"); + } + } + } + } + } + + for (Value *input : component.inputs) { + // Handle inputs for which the relayout can be folded or absorbed. + if (auto *call = dyn_cast(input)) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixConvert)) { + unsigned srcElemType = cast(call->getArgOperand(2))->getZExtValue(); + unsigned dstElemType = cast(call->getArgOperand(3))->getZExtValue(); + + if (srcElemType == dstElemType) { + unsigned srcLayout = + (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(4))->getZExtValue(); + assert(srcLayout == *otherLayout); + (void(srcLayout)); // unused + + Value *src = call->getArgOperand(1); + foldTo(input, src); + + // Pre-populate the cache to re-use the relayout operation instead of creating a new one. + outRelayouted.try_emplace(src, input); + continue; + } + + // Integrate the relayouting into a merged conversion op. + call->setArgOperand(5, b.getInt32((unsigned)*otherLayout)); + continue; + } + if (callee->getName().startswith(lgcName::CooperativeMatrixLoad)) { + call->setArgOperand(4, b.getInt32((unsigned)*otherLayout)); + continue; + } + } + } + + // Handle generic inputs that need a new convert operation inserted. + if (auto *inst = dyn_cast(input)) { + b.SetInsertPoint(inst->getNextNode()); + } else { + assert(isa(input)); + b.SetInsertPointPastAllocas(&m_function); + } + + CallInst *convert = b.CreateCooperativeMatrixConvert((CastInst::CastOps)0, PoisonValue::get(input->getType()), + component.shape->elementType, component.shape->elementType, + component.shape->layout, *otherLayout); + foldTo(input, convert); + convert->setArgOperand(1, input); + } + + for (Use *use : component.outputs) { + // Handle outputs for which the relayout can be folded or absorbed. + if (auto *call = dyn_cast(use->getUser())) { + if (auto *callee = call->getCalledFunction()) { + if (callee->getName().startswith(lgcName::CooperativeMatrixConvert)) { + unsigned srcElemType = cast(call->getArgOperand(2))->getZExtValue(); + unsigned dstElemType = cast(call->getArgOperand(3))->getZExtValue(); + + if (srcElemType == dstElemType) { + unsigned dstLayout = + (Builder::CooperativeMatrixLayout)cast(call->getArgOperand(5))->getZExtValue(); + assert(dstLayout == *otherLayout); + (void(dstLayout)); // unused + + foldTo(call, use->get()); + continue; + } + } + if (callee->getName().startswith(lgcName::CooperativeMatrixStore)) { + call->setArgOperand(4, b.getInt32((unsigned)*otherLayout)); + continue; + } + } + } + + // Handle generic outputs that need a new convert operation inserted. + Value *&relayouted = outRelayouted[use->get()]; + if (!relayouted) { + if (auto *phi = cast(use->get())) { + b.SetInsertPoint(phi->getParent(), phi->getParent()->getFirstInsertionPt()); + } else { + auto *def = cast(use->get()); + b.SetInsertPoint(def->getNextNode()); + } + + relayouted = + b.CreateCooperativeMatrixConvert((CastInst::CastOps)0, use->get(), component.shape->elementType, + component.shape->elementType, *otherLayout, component.shape->layout); + } + + use->set(relayouted); + } + + return true; + } + + return false; +} + +// ===================================================================================================================== +// Run the pass on a function. +// +// @param [in/out] function : LLVM function to be run on +// @param [in/out] analysisManager : Analysis manager to use for this transformation +// @returns : The preserved analyses (The Analyses that are still valid after this pass) +PreservedAnalyses CombineCooperativeMatrix::run(Function &function, FunctionAnalysisManager &analysisManager) { + CooperativeMatrixCombiner combiner{function}; + + if (combiner.run()) { + PreservedAnalyses PA; + PA.preserveSet(); + return PA; + } + return PreservedAnalyses::all(); +} diff --git a/lgc/patch/Continufy.cpp b/lgc/patch/Continufy.cpp index 8505b1096b..18dedc09b2 100644 --- a/lgc/patch/Continufy.cpp +++ b/lgc/patch/Continufy.cpp @@ -29,7 +29,9 @@ * This pass translates indirect call into cps.await call, which will be lowered into continuation call. *********************************************************************************************************************** */ + #include "lgc/patch/Continufy.h" +#include "compilerutils/CompilerUtils.h" #include "lgc/Builder.h" #include "lgc/LgcCpsDialect.h" #include "lgc/LgcDialect.h" @@ -54,7 +56,7 @@ static Function *insertCpsArguments(Function &fn) { auto *fnTy = fn.getFunctionType(); argTys.append(fnTy->params().begin(), fnTy->params().end()); - auto *newFn = mutateFunctionArguments(fn, Type::getVoidTy(context), argTys, fn.getAttributes()); + auto *newFn = CompilerUtils::mutateFunctionArguments(fn, Type::getVoidTy(context), argTys, fn.getAttributes()); fn.replaceAllUsesWith(newFn); for (unsigned idx = 0; idx < fn.arg_size(); idx++) { diff --git a/lgc/patch/FragColorExport.cpp b/lgc/patch/FragColorExport.cpp index a9f042a572..40ae489eab 100644 --- a/lgc/patch/FragColorExport.cpp +++ b/lgc/patch/FragColorExport.cpp @@ -920,15 +920,28 @@ void FragColorExport::generateExportInstructions(ArrayRef info, info = info.drop_front(1); } + // Record each color buffer's export info for broadcasting + llvm::SmallVector broadCastInfo; + if (m_pipelineState->getOptions().enableFragColor) { + auto &expInfo = info[0]; + assert(expInfo.ty != nullptr); + + for (unsigned location = 0; location < MaxColorTargets; ++location) { + if (m_pipelineState->getColorExportFormat(location).dfmt != BufDataFormatInvalid) + broadCastInfo.push_back({0, location, expInfo.isSigned, expInfo.ty}); + } + } + // Now do color exports by color buffer. unsigned hwColorExport = 0; + auto finalExpInfo = m_pipelineState->getOptions().enableFragColor ? ArrayRef(broadCastInfo) : info; for (unsigned location = 0; location < MaxColorTargets; ++location) { - auto infoIt = llvm::find_if(info, [&](const ColorExportInfo &info) { return info.location == location; }); - if (infoIt == info.end()) + auto infoIt = llvm::find_if(finalExpInfo, + [&](const ColorExportInfo &finalExpInfo) { return finalExpInfo.location == location; }); + if (infoIt == finalExpInfo.end()) continue; assert(infoIt->hwColorTarget < MaxColorTargets); - auto expFmt = static_cast(m_pipelineState->computeExportFormat(infoIt->ty, location)); unsigned channelWriteMask = m_pipelineState->getColorExportFormat(location).channelWriteMask; bool needExpInst = (expFmt != EXP_FORMAT_ZERO) && diff --git a/lgc/patch/Gfx9ConfigBuilder.cpp b/lgc/patch/Gfx9ConfigBuilder.cpp index 5213d1e222..6aa70f824e 100644 --- a/lgc/patch/Gfx9ConfigBuilder.cpp +++ b/lgc/patch/Gfx9ConfigBuilder.cpp @@ -887,10 +887,7 @@ template void ConfigBuilder::buildVsRegConfig(ShaderStage shaderSta const bool enableXfb = m_pipelineState->enableXfb(); const bool enablePrimStats = m_pipelineState->enablePrimStats(); if (shaderStage == ShaderStageCopyShader) { - // NOTE: For copy shader, usually we use fixed number of user data registers. - // But in some cases, we may change user data registers, we use variable to keep user sgpr count here - auto copyShaderUserSgprCount = lgc::CopyShaderUserSgprCount; - SET_REG_FIELD(&config->vsRegs, SPI_SHADER_PGM_RSRC2_VS, USER_SGPR, copyShaderUserSgprCount); + SET_REG_FIELD(&config->vsRegs, SPI_SHADER_PGM_RSRC2_VS, USER_SGPR, lgc::CopyShaderUserSgprCount); setNumAvailSgprs(Util::Abi::HardwareStage::Vs, m_pipelineState->getTargetInfo().getGpuProperty().maxSgprsAvailable); setNumAvailVgprs(Util::Abi::HardwareStage::Vs, m_pipelineState->getTargetInfo().getGpuProperty().maxVgprsAvailable); diff --git a/lgc/patch/LowerCooperativeMatrix.cpp b/lgc/patch/LowerCooperativeMatrix.cpp new file mode 100644 index 0000000000..4582205f70 --- /dev/null +++ b/lgc/patch/LowerCooperativeMatrix.cpp @@ -0,0 +1,1897 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file LowerCooperativeMatrix.cpp + * @brief LGC source file : Lower CooperativeMatrix manager, and pass that uses it + *********************************************************************************************************************** + */ +#include "lgc/patch/LowerCooperativeMatrix.h" +#include "lgc/Builder.h" +#include "lgc/LgcContext.h" +#include "lgc/state/IntrinsDefs.h" +#include "lgc/state/PipelineShaders.h" +#include "lgc/state/PipelineState.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "lgc-lower-cooperative-matrix" + +using namespace llvm; +using namespace lgc; + +namespace lgc { + +// ===================================================================================================================== +// Run the patch cooperative matrix pass on a module +// +// @param [in/out] module : LLVM module to be run on +// @param [in/out] analysisManager : Analysis manager to use for this transformation +// @returns : The preserved analyses (The Analyses that are still valid after this pass) +PreservedAnalyses LowerCooperativeMatrix::run(Module &module, ModuleAnalysisManager &analysisManager) { + PipelineState *pipelineState = analysisManager.getResult(module).getPipelineState(); + PipelineShadersResult &pipelineShaders = analysisManager.getResult(module); + + if (runImpl(module, pipelineShaders, pipelineState)) { + PreservedAnalyses PA; + PA.preserveSet(); + return PA; + } + return PreservedAnalyses::all(); +} + +// ===================================================================================================================== +// Run the on a module +// +// @param [in/out] module : LLVM module to be run on +// @param pipelineState : Pipeline state +// @returns : True if the module was modified by the transformation and false otherwise +bool LowerCooperativeMatrix::runImpl(Module &module, PipelineShadersResult &pipelineShaders, + PipelineState *pipelineState) { + LLVM_DEBUG(dbgs() << "Run the pass Patch-Cooperative-Matrix\n"); + Patch::init(&module); + m_pipelineState = pipelineState; + m_pipelineShaders = &pipelineShaders; + m_shaderStage = ShaderStageCompute; + m_gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion(); + + SmallVector lowerCoopMatrixCallees; + for (auto &func : module) { + auto name = func.getName(); + if (name.startswith(lgcName::CooperativeMatrix)) + lowerCoopMatrixCallees.push_back(&func); + } + if (lowerCoopMatrixCallees.empty()) + return false; + + processCoopMatrixFunction(lowerCoopMatrixCallees); + + for (auto callInst : m_coopMatrixCalls) { + callInst->dropAllReferences(); + callInst->eraseFromParent(); + } + m_coopMatrixCalls.clear(); + return true; +} + +// ===================================================================================================================== +// Run the on a module +// +// @param coopMatrixCallees : Function array for the cooperativeMatrix +void LowerCooperativeMatrix::processCoopMatrixFunction(ArrayRef coopMatrixCallees) { + for (auto callee : coopMatrixCallees) { + for (auto user : callee->users()) { + if (CallInst *callInst = dyn_cast(user)) { + visitCallInst(*callInst); + } + } + } +} + +// ===================================================================================================================== +// Visits "call" instruction. +// +// @param callInst : "Call" instruction +void LowerCooperativeMatrix::visitCallInst(CallInst &callInst) { + auto callee = callInst.getCalledFunction(); + if (!callee) + return; + + m_coopMatrixCalls.push_back(&callInst); + + BuilderCommon builder(*m_context); + builder.SetInsertPoint(&callInst); + + auto mangledName = callee->getName(); + if (mangledName.startswith(lgcName::CooperativeMatrixLength)) { + auto layout = + static_cast(cast(callInst.getOperand(1))->getZExtValue()); + callInst.replaceAllUsesWith(builder.getInt32(getLength(layout))); + } else if (mangledName.startswith(lgcName::CooperativeMatrixExtract)) { + Value *matrix = callInst.getOperand(0); + Value *index = callInst.getOperand(1); + auto elemType = + static_cast(cast(callInst.getOperand(2))->getZExtValue()); + auto layout = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + Value *result = cooperativeMatrixExtract(builder, matrix, index, elemType, layout); + result->takeName(&callInst); + callInst.replaceAllUsesWith(result); + } else if (mangledName.startswith(lgcName::CooperativeMatrixInsert)) { + Value *matrix = callInst.getOperand(0); + Value *value = callInst.getOperand(1); + Value *index = callInst.getOperand(2); + auto elemType = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + auto layout = + static_cast(cast(callInst.getOperand(4))->getZExtValue()); + Value *result = cooperativeMatrixInsert(builder, matrix, value, index, elemType, layout); + result->takeName(&callInst); + callInst.replaceAllUsesWith(result); + } else if (mangledName.startswith(lgcName::CooperativeMatrixLoad)) { + Value *dataPtr = callInst.getOperand(0); + Value *stride = callInst.getOperand(1); + bool colMajor = cast(callInst.getOperand(2))->getZExtValue(); + Builder::CooperativeMatrixElementType elemType = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + Builder::CooperativeMatrixLayout layout = + static_cast(cast(callInst.getOperand(4))->getZExtValue()); + unsigned memoryAccess = cast(callInst.getOperand(5))->getZExtValue(); + + Value *loadVal = cooperativeMatrixLoadInternal(dataPtr, stride, colMajor, elemType, layout, memoryAccess, + callInst.getName(), &callInst); + callInst.replaceAllUsesWith(loadVal); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixStore)) { + Value *dataPtr = callInst.getOperand(0); + Value *stride = callInst.getOperand(1); + bool colMajor = cast(callInst.getOperand(2))->getZExtValue(); + Builder::CooperativeMatrixElementType elemType = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + Builder::CooperativeMatrixLayout layout = + static_cast(cast(callInst.getOperand(4))->getZExtValue()); + unsigned memoryAccess = cast(callInst.getOperand(5))->getZExtValue(); + Value *vecVal = callInst.getOperand(6); + + cooperativeMatrixStoreInternal(dataPtr, stride, colMajor, elemType, layout, memoryAccess, vecVal, + callInst.getName(), &callInst); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixConvert)) { + CastInst::CastOps castOp = + static_cast(cast(callInst.getOperand(0))->getZExtValue()); + Value *source = callInst.getOperand(1); + Builder::CooperativeMatrixElementType srcElemType = + static_cast(cast(callInst.getOperand(2))->getZExtValue()); + Builder::CooperativeMatrixElementType dstElemType = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + Builder::CooperativeMatrixLayout srcLayout = + static_cast(cast(callInst.getOperand(4))->getZExtValue()); + Builder::CooperativeMatrixLayout dstLayout = + static_cast(cast(callInst.getOperand(5))->getZExtValue()); + Value *resultVal = cooperativeMatrixConvert(castOp, source, srcElemType, dstElemType, srcLayout, dstLayout, + callInst.getName(), &callInst); + if ((cast(resultVal->getType())->getNumElements() == 4) && + (dstLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout || + dstLayout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout || + dstLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout)) { + // for wave64 needs shuffleVector from V4 to V8 as frontend will always recognize V8 not care wave32 or wave64 + resultVal = builder.CreateShuffleVector(resultVal, PoisonValue::get(resultVal->getType()), + ArrayRef{0, 1, 2, 3, 4, 5, 6, 7}); + } + callInst.replaceAllUsesWith(resultVal); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixTranspose)) { + Value *matrix = callInst.getOperand(0); + Builder::CooperativeMatrixElementType elemType = + static_cast(cast(callInst.getOperand(1))->getZExtValue()); + Builder::CooperativeMatrixLayout srcLayout = + static_cast(cast(callInst.getOperand(2))->getZExtValue()); + + Value *resultVal = cooperativeMatrixTranspose(matrix, elemType, srcLayout, callInst.getName(), &callInst); + callInst.replaceAllUsesWith(resultVal); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixBinOp)) { + Builder::CooperativeMatrixArithOp coopMatArithOp = + static_cast(cast(callInst.getOperand(0))->getZExtValue()); + Value *lhs = callInst.getOperand(1); + Value *rhs = callInst.getOperand(2); + Builder::CooperativeMatrixElementType elemType = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + Builder::CooperativeMatrixLayout srcLayout = + static_cast(cast(callInst.getOperand(4))->getZExtValue()); + + Value *resultVal = + cooperativeMatrixBinaryOp(coopMatArithOp, lhs, rhs, elemType, srcLayout, callInst.getName(), &callInst); + callInst.replaceAllUsesWith(resultVal); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixTimesScalar)) { + Value *matrix = callInst.getOperand(0); + Value *scalar = callInst.getOperand(1); + Builder::CooperativeMatrixElementType elemType = + static_cast(cast(callInst.getOperand(2))->getZExtValue()); + Builder::CooperativeMatrixLayout srcLayout = + static_cast(cast(callInst.getOperand(3))->getZExtValue()); + + Value *resultVal = coopMatrixTimesScalar(matrix, scalar, elemType, srcLayout, callInst.getName(), &callInst); + callInst.replaceAllUsesWith(resultVal); + + } else if (mangledName.startswith(lgcName::CooperativeMatrixMulAdd)) { + Value *matrixA = callInst.getOperand(0); + Value *matrixB = callInst.getOperand(1); + Value *matrixC = callInst.getOperand(2); + bool isSignedA = cast(callInst.getOperand(3))->getZExtValue(); + bool isSignedB = cast(callInst.getOperand(4))->getZExtValue(); + bool isSat = cast(callInst.getOperand(5))->getZExtValue(); + Builder::CooperativeMatrixElementType accumElemType = + static_cast(cast(callInst.getOperand(6))->getZExtValue()); + Builder::CooperativeMatrixElementType factorElemType = + static_cast(cast(callInst.getOperand(7))->getZExtValue()); + Value *resultVal = cooperativeMatrixMulAdd(matrixA, matrixB, matrixC, isSignedA, isSignedB, isSat, accumElemType, + factorElemType, callInst.getName(), &callInst); + callInst.replaceAllUsesWith(resultVal); + + } else { + llvm_unreachable("Should never be called!"); + } +} + +// ===================================================================================================================== +// Get the "length" of a matrix of the given layout, i.e. the number of matrix components stored per lane. +// +// @param layout : the matrix layout +unsigned LowerCooperativeMatrix::getLength(Builder::CooperativeMatrixLayout layout) const { + auto waveSize = m_pipelineState->getShaderWaveSize(m_shaderStage); + switch (layout) { + case BuilderCommon::FactorMatrixLayout: + return 16; + case BuilderCommon::AccumulatorMatrixLayout: { + return waveSize == 32 ? 8 : 4; + } + case BuilderCommon::Gfx10AccumulatorMatrixLayout: + case BuilderCommon::Gfx10Accumulator16bitMatrixLayout: + return 8; + default: + llvm_unreachable("unhandled matrix layout"); + } +} + +// ===================================================================================================================== +// Determine properties of the cooperative matrix type depending on element type, layout, and wave size. +// +// @param elemType : the matrix element type +// @param layout : the matrix layout +// @returns : the type properties +LowerCooperativeMatrix::TypeProperties +LowerCooperativeMatrix::getTypeProperties(Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) const { + TypeProperties props; + + props.matrixElementStride = 1; + + switch (elemType) { + case Builder::CooperativeMatrixElementType::Float32: + case Builder::CooperativeMatrixElementType::Int32: + props.numMatrixElements = 8; + props.numMatrixWords = 8; + break; + case Builder::CooperativeMatrixElementType::Float16: + case Builder::CooperativeMatrixElementType::Int16: + props.numMatrixElements = 16; + props.numMatrixWords = 8; + break; + case Builder::CooperativeMatrixElementType::Int8: + props.numMatrixElements = 16; + props.numMatrixWords = 4; + break; + default: + llvm_unreachable("unknown element type"); + } + + auto waveSize = m_pipelineState->getShaderWaveSize(m_shaderStage); + if (layout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + assert(elemType != Builder::CooperativeMatrixElementType::Float32 && + elemType != Builder::CooperativeMatrixElementType::Int32); + props.numFlatElements = 16; + } else if (layout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) { + props.numFlatElements = waveSize == 32 ? 8 : 4; + if (elemType == Builder::CooperativeMatrixElementType::Float16 || + elemType == Builder::CooperativeMatrixElementType::Int16) { + props.matrixElementStride = 2; + } + } else if (layout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout || + layout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout) { + props.numFlatElements = 8; + } else { + llvm_unreachable("Unsupported layout!"); + } + + return props; +} + +// ===================================================================================================================== +// Create cooperative Matrix data(C/D:V8/V4 A/B: V8/V4) from vector value(C/D wave32:V8 wave64:V4 A/B: V16) +// +// @param builder : the builder to use +// @param vecValue : Vector Value which maybe V16. +// @param elemType : Element type for the matrix. +// @param layout : Identify whether this matrix is A/B or C/D +Value *LowerCooperativeMatrix::convFlatVecToCoopMatrixVec(BuilderCommon &builder, Value *vecValue, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) { + auto props = getTypeProperties(elemType, layout); + + if (props.numMatrixElements > props.numFlatElements) { + SmallVector mask; + for (unsigned i = 0; i < props.numMatrixElements / props.matrixElementStride; ++i) { + mask.push_back(i); + for (unsigned j = 1; j < props.matrixElementStride; ++j) + mask.push_back(-1); + } + vecValue = builder.CreateShuffleVector(vecValue, PoisonValue::get(vecValue->getType()), mask); + } + + Type *wordTy = vecValue->getType()->isIntOrIntVectorTy() ? builder.getInt32Ty() : builder.getFloatTy(); + return builder.CreateBitCast(vecValue, FixedVectorType::get(wordTy, props.numMatrixWords)); +} + +// ===================================================================================================================== +// Create vector value(C/D wave32:V8 wave64:V4 A/B: V16) from cooperative Matrix data(C/D:V8/V4 A/B: V8/V4) +// +// @param builder : the builder to use +// @param matrixValue : Vector Value which maybe V16. +// @param elemType : Element type for the matrix. +// @param layout : Identify whether this matrix is A/B or C/D +Value *LowerCooperativeMatrix::convCoopMatrixVecToFlatVec(BuilderCommon &builder, Value *matrixValue, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) { + auto props = getTypeProperties(elemType, layout); + + Type *flatType = FixedVectorType::get(builder.transCooperativeMatrixElementType(elemType), props.numMatrixElements); + Value *tmp = builder.CreateBitCast(matrixValue, flatType); + + if (props.numFlatElements < props.numMatrixElements) { + SmallVector mask; + for (unsigned i = 0; i < props.numFlatElements; ++i) + mask.push_back(i * props.matrixElementStride); + tmp = builder.CreateShuffleVector(tmp, PoisonValue::get(tmp->getType()), mask); + } + + return tmp; +} + +// ===================================================================================================================== +// Load contiguous elements from the specified location of the memory. +// @param layout : This is identify for factor(A/B) or accumulator(C) for 16 bit element matrix. +// @param elemType : The element type for the matrix. +// @param waveSize : Identify it's in wave32 or wave64. +// @param stride : The stride in bytes in memory between the first elements of consecutive rows (orcolumns) in the +// source data. Guaranteed to be a multiple of the matrix element size. +// @param isColMajor : Identify the order for the data stored in memory, col-major/row-major +// @param insertPos : Where to insert the instruction +LowerCooperativeMatrix::ComputeAddressInfo +LowerCooperativeMatrix::computeAddressing(Builder::CooperativeMatrixLayout layout, + Builder::CooperativeMatrixElementType elemType, int waveSize, Value *stride, + bool isColMajor, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *threadId = getLaneNumber(builder); + ComputeAddressInfo addrInfo; + Value *rowOffsetInFirstVgpr = nullptr; + Value *colOffsetPerLane = builder.CreateSRem(threadId, builder.getInt32(16)); + addrInfo.microStep = builder.getInt32(0); + addrInfo.microCount = 1; + (void)elemType; + + if (layout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + rowOffsetInFirstVgpr = builder.getInt32(0); + addrInfo.macroStep = builder.getInt32(1); + } else if (layout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) { + rowOffsetInFirstVgpr = builder.CreateUDiv(threadId, builder.getInt32(16)); + addrInfo.macroStep = (waveSize == 64 ? builder.getInt32(4) : builder.getInt32(2)); + } else if (layout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout) { + rowOffsetInFirstVgpr = builder.CreateUDiv(builder.CreateSRem(threadId, builder.getInt32(32)), builder.getInt32(16)); + addrInfo.macroStep = builder.getInt32(2); + } else if (layout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout) { + // For 16bit@Accumulator@gfx10:lane_0: {0_0,1_0,4_0,5_0,8_0,9_0,12_0,13_0} + // lane_16: {2_0,3_0,6_0,7_0,10_0,11_0,14_0,15_0} on lane_16. + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *evenGroup = builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + addrInfo.microCount = 2; + rowOffsetInFirstVgpr = builder.CreateSelect(evenGroup, builder.getInt32(0), builder.getInt32(2)); + addrInfo.macroStep = builder.getInt32(4); + addrInfo.microStep = builder.getInt32(1); + } else { + llvm_unreachable("This layout is not supported now."); + } + + if (isColMajor) { + addrInfo.base = builder.CreateAdd(rowOffsetInFirstVgpr, builder.CreateMul(colOffsetPerLane, stride)); + } else { + addrInfo.base = builder.CreateAdd(builder.CreateMul(rowOffsetInFirstVgpr, stride), colOffsetPerLane); + addrInfo.macroStep = builder.CreateMul(addrInfo.macroStep, stride); + addrInfo.microStep = builder.CreateMul(addrInfo.microStep, stride); + } + + return addrInfo; +} + +// ===================================================================================================================== +// Load contiguous elements from the specified location of the memory. +// @param dataPtr : The pointer to a data array. +// @param stride : The stride in bytes in memory between the first elements of consecutive rows (orcolumns) in the +// source data. Guaranteed to be a multiple of the matrix element size. +// @param isColMajor : Identify the order for the data stored in memory, col-major/row-major +// @param elemType : The element type for the matrix +// @param layout : This is identify for factor(A/B) or accumulator(C) for 16 bit element matrix. +// @param memoryAccess : The memory operands which provide:isVolatile/isTemporal/isCoherent +// additional operands, maybe volatile/Aligned/Nontemporal/MakePointerAvailable +// /MakePointerVisible/NonPrivatePointer usded by CooperativeMatrix Load/Store. +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixLoadInternal(Value *dataPtr, Value *stride, bool isColMajor, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, + unsigned memoryAccess, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + auto waveSize = m_pipelineState->getShaderWaveSize(getShaderStage(builder.GetInsertBlock()->getParent())); + assert(waveSize == 32 || waveSize == 64); + + // Calc element offset in memory + Type *elemTy = builder.transCooperativeMatrixElementType(elemType); + const unsigned dataBitwidth = elemTy->getScalarSizeInBits(); + const unsigned addrSpace = dataPtr->getType()->getPointerAddressSpace(); + assert(addrSpace == ADDR_SPACE_LOCAL || addrSpace == ADDR_SPACE_BUFFER_FAT_POINTER || addrSpace == ADDR_SPACE_GLOBAL); + + stride = builder.CreateExactSDiv(stride, builder.getInt32(dataBitwidth / 8)); + + // calc memoryAccess + bool isVolatile = memoryAccess & Builder::MemoryAccessVolatileMask; + bool isCoherent = memoryAccess & Builder::MemoryAccessCoherentMask; + bool isTemporal = memoryAccess & Builder::MemoryAccessTemporalMask; + + auto props = getTypeProperties(elemType, layout); + auto addrInfo = computeAddressing(layout, elemType, waveSize, stride, isColMajor, insertPos); + + Value *vecVal = PoisonValue::get(FixedVectorType::get(elemTy, props.numFlatElements)); + for (unsigned idx = 0; idx < props.numFlatElements; ++idx) { + Value *offset = builder.CreateAdd( + addrInfo.base, builder.CreateMul(addrInfo.macroStep, builder.getInt32(idx / addrInfo.microCount))); + offset = + builder.CreateAdd(offset, builder.CreateMul(addrInfo.microStep, builder.getInt32(idx % addrInfo.microCount))); + + Value *elePtr = builder.CreateGEP(elemTy, dataPtr, offset); + Value *eleVal = builder.CreateLoad(elemTy, elePtr, isVolatile, instName); + if (isCoherent && !(addrSpace == ADDR_SPACE_LOCAL && dataBitwidth < 32)) + cast(eleVal)->setAtomic(AtomicOrdering::Unordered); + if (isTemporal) + cast(eleVal)->setMetadata(LLVMContext::MD_nontemporal, MDNode::get(builder.getContext(), {})); + vecVal = builder.CreateInsertElement(vecVal, eleVal, idx); + } + + Value *coMatrix = convFlatVecToCoopMatrixVec(builder, vecVal, elemType, layout); + return coMatrix; +} + +// ===================================================================================================================== +// Store a contiguous elements from the specified location of the memory. +// +// @param dataPtr : The pointer to a data array. +// @param stride : The stride in bytes between the first elements of consecutive rows (or columns) in the destination. +// Guaranteed to be a multiple of the element size. +// @param colMajor : Identify the order for the data stored in memory, col-major/row-major +// @param elemType : The type for the element. +// @param layout : This is identify for factor(A/B) or accumulator(C) for 16 bit element matrix. +// @param memoryAccess : The memory operands which provide +// additional operands, maybe volatile/Aligned/Nontemporal/MakePointerAvailable +// /MakePointerVisible/NonPrivatePointer used by CooperativeMatrix Load/Store. +// @param vecVal : The contiguous elements made up of a vector to be loaded or stored. +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +void LowerCooperativeMatrix::cooperativeMatrixStoreInternal(Value *dataPtr, Value *stride, bool isColMajor, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, + unsigned memoryAccess, Value *&vecVal, + const Twine &instName, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + auto waveSize = m_pipelineState->getShaderWaveSize(getShaderStage(builder.GetInsertBlock()->getParent())); + assert(waveSize == 32 || waveSize == 64); + + // Calc element offset in memory + Type *elemTy = builder.transCooperativeMatrixElementType(elemType); + const unsigned dataBitwidth = elemTy->getScalarSizeInBits(); + const unsigned addrSpace = dataPtr->getType()->getPointerAddressSpace(); + assert(addrSpace == ADDR_SPACE_LOCAL || addrSpace == ADDR_SPACE_BUFFER_FAT_POINTER || addrSpace == ADDR_SPACE_GLOBAL); + + stride = builder.CreateExactSDiv(stride, builder.getInt32(dataBitwidth / 8)); + + // calc memoryAccess + bool isVolatile = memoryAccess & Builder::MemoryAccessVolatileMask; + bool isCoherent = memoryAccess & Builder::MemoryAccessCoherentMask; + bool isTemporal = memoryAccess & Builder::MemoryAccessTemporalMask; + + auto props = getTypeProperties(elemType, layout); + auto addrInfo = computeAddressing(layout, elemType, waveSize, stride, isColMajor, insertPos); + + vecVal = convCoopMatrixVecToFlatVec(builder, vecVal, elemType, layout); + + for (unsigned idx = 0; idx < props.numFlatElements; ++idx) { + Value *offset = builder.CreateAdd( + addrInfo.base, builder.CreateMul(addrInfo.macroStep, builder.getInt32(idx / addrInfo.microCount))); + offset = + builder.CreateAdd(offset, builder.CreateMul(addrInfo.microStep, builder.getInt32(idx % addrInfo.microCount))); + Value *elePtr = builder.CreateGEP(elemTy, dataPtr, offset); + Value *oneElement = builder.CreateExtractElement(vecVal, idx); + StoreInst *st = builder.CreateStore(oneElement, elePtr, isVolatile); + + if (isCoherent && !(addrSpace == ADDR_SPACE_LOCAL && dataBitwidth < 32)) + st->setAtomic(AtomicOrdering::Unordered); + if (isTemporal) + st->setMetadata(LLVMContext::MD_nontemporal, MDNode::get(builder.getContext(), {})); + } +} + +// ===================================================================================================================== +// Open-code cooperative matrix extract operation +// +// @param builder : builder to use +// @param matrix : the matrix from which to extract a component +// @param index : the index to be extracted +// @param elemType : the matrix element type +// @param layout : the matrix layout type +Value *LowerCooperativeMatrix::cooperativeMatrixExtract(BuilderCommon &builder, Value *matrix, Value *index, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) { + Value *vec = convCoopMatrixVecToFlatVec(builder, matrix, elemType, layout); + + // This is a hacky workaround to the fact that for SPV_NV_cooperative_matrix, we have to support matrix length as + // a specialization constant even though, at the time of specialization constant lowering, we don't yet know the + // wave size. We should remove this once a healther KHR extension has been released. + if (layout == BuilderCommon::CooperativeMatrixLayout::AccumulatorMatrixLayout && + m_pipelineState->getShaderWaveSize(m_shaderStage) == 64) { + unsigned length = cast(vec->getType())->getNumElements(); + index = builder.CreateAnd(index, builder.getInt32(length - 1)); + } + + return builder.CreateExtractElement(vec, index); +} + +// ===================================================================================================================== +// Open-code cooperative matrix insert operation +// +// @param builder : builder to use +// @param matrix : the matrix into which to insert a component +// @param value : the value to be inserted +// @param index : the index to be inserted +// @param elemType : the matrix element type +// @param layout : the matrix layout type +Value *LowerCooperativeMatrix::cooperativeMatrixInsert(BuilderCommon &builder, Value *matrix, Value *value, + Value *index, Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout) { + Value *vec = convCoopMatrixVecToFlatVec(builder, matrix, elemType, layout); + + // This is a hacky workaround to the fact that for SPV_NV_cooperative_matrix, we have to support matrix length as + // a specialization constant even though, at the time of specialization constant lowering, we don't yet know the + // wave size. We should remove this once a healther KHR extension has been released. + if (layout == BuilderCommon::CooperativeMatrixLayout::AccumulatorMatrixLayout && + m_pipelineState->getShaderWaveSize(m_shaderStage) == 64) { + unsigned length = cast(vec->getType())->getNumElements(); + Value *outOfBounds = builder.CreateICmpUGE(index, builder.getInt32(length)); + index = builder.CreateAnd(index, builder.getInt32(length - 1)); + Value *newVec = builder.CreateInsertElement(vec, value, index); + vec = builder.CreateSelect(outOfBounds, vec, newVec); + } else { + vec = builder.CreateInsertElement(vec, value, index); + } + + return convFlatVecToCoopMatrixVec(builder, vec, elemType, layout); +} + +// ===================================================================================================================== +// Create cooperative matrix conversion without any reshape operations +// Element-wise-conversion +// @param castOp : The cast Opcode. +// @param source : The source cooperative matrix. +// @param dstElemType : Source matrix's element type. +// @param dstElemType : Destination matrix's element type. +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixConvertInternal(CastInst::CastOps castOp, Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + const Twine &instName, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = nullptr; + const unsigned vecSize = cast(source->getType())->getNumElements(); + Type *dstType = FixedVectorType::get(builder.transCooperativeMatrixElementType(dstElemType), vecSize); + + if ((srcElemType == Builder::CooperativeMatrixElementType::Float16 || + srcElemType == Builder::CooperativeMatrixElementType::Float32) && + (castOp == Instruction::FPToUI || castOp == Instruction::FPToSI)) { + // FIXME: fp16's range is covered by i32. So `fptoi half` can convert + // to i32 first following a sext/zext to target integer type. + // Fix the error in: dEQP-VK.compute.cooperative_matrix.nv.convert.input_float16/32_t_output_uint8_t* + resultValue = + builder.CreateCast(castOp, source, FixedVectorType::get(builder.getInt32Ty(), vecSize), "ConvertIntoInt32"); + if (builder.transCooperativeMatrixElementType(dstElemType)->getScalarSizeInBits() < 32) { + resultValue = builder.CreateTrunc(resultValue, dstType); + } + } else { + resultValue = builder.CreateCast(castOp, source, dstType, "castOpConvert"); + } + + return resultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix conversion. +// Element-wise-conversion +// @param castOp : The cast Opcode. +// @param source : The source cooperative matrix. +// @param srcElemType : Source matrix's element type. +// @param dstElemType : Destination matrix's element type. +// @param srcLayout : Layout for source matrix +// @param dstLayout : Layout for destination matrix +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixConvert(CastInst::CastOps castOp, Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + const Twine &instName, Instruction *insertPos) { + assert(source->getType()->isVectorTy()); + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = nullptr; + Value *threadId = getLaneNumber(builder); + + if (castOp == 0) { // Only reshape on 16bits, not do convert + if ((srcLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) && + (dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout)) { + // After mulAdd, the type for the matrix waiting to reshape is 8*float here + const unsigned vecNums = cast(source->getType())->getNumElements(); + source = builder.CreateBitCast(source, FixedVectorType::get(builder.getInt32Ty(), vecNums)); + } + resultValue = cooperativeMatrixReshape16BitElementGfx1011(source, srcElemType, srcLayout, dstLayout, threadId, + instName, insertPos); + } else { + unsigned numSrcBit = builder.transCooperativeMatrixElementType(srcElemType)->getScalarSizeInBits(); + unsigned numDstBit = builder.transCooperativeMatrixElementType(dstElemType)->getScalarSizeInBits(); + + // Step 1: Some cases need change the layout due to different element types before conversion. + if ((numSrcBit < numDstBit) && (srcLayout != dstLayout)) { + // Need Reshape from A/B layout to C/D layout + // This interface will do cooperativeVecToflatVec internally except 8bit reshape. + source = cooperativeMatrixReshapeBeforeConvert(source, srcElemType, dstElemType, srcLayout, dstLayout, instName, + insertPos); + } else { + // For 16bit->32bit on Gfx11, no reshape needed as it will always in AccumulatorMatrixLayout + source = convCoopMatrixVecToFlatVec(builder, source, srcElemType, srcLayout); + } + + // Step 2: Just do flatElement conversion without any layout change. + resultValue = cooperativeMatrixConvertInternal(castOp, source, srcElemType, dstElemType, instName, insertPos); + + // Step 3: Some cases need change the layout due to different element types after conversion. + if ((numSrcBit > numDstBit) && (srcLayout != dstLayout)) { + // All these reshape interfaces will return N*packetTy. + // Need Reshape from A/B layout to C/D layout + resultValue = cooperativeMatrixReshapeAfterConvert(resultValue, srcElemType, dstElemType, srcLayout, dstLayout, + instName, insertPos); + } else { + resultValue = convFlatVecToCoopMatrixVec(builder, resultValue, dstElemType, dstLayout); + } + } + return resultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix binary operation +// +// @param coopMatArithOp : The cooperative matrix arithmetic operation to perform. +// @param lhs : The first operand and it can be a scalar or a cooperative matrix. +// @param rhs : The second operand and it should be a cooperative matrix. +// @param elemType : Element type for the matrix. +// @param layout : Layout for the matrix. +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixBinaryOp(Builder::CooperativeMatrixArithOp coopMatArithOp, Value *lhs, + Value *rhs, Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, const Twine &instName, + Instruction *insertPos) { + assert(lhs->getType()->isVectorTy() && lhs->getType() == rhs->getType() || rhs->getType()->isVectorTy()); + Value *vcResult; + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + lhs = convCoopMatrixVecToFlatVec(builder, lhs, elemType, layout); + rhs = convCoopMatrixVecToFlatVec(builder, rhs, elemType, layout); + switch (coopMatArithOp) { + case Builder::CooperativeMatrixArithOp::IAdd: + vcResult = builder.CreateAdd(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::FAdd: + vcResult = builder.CreateFAdd(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::ISub: + vcResult = builder.CreateSub(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::FSub: + vcResult = builder.CreateFSub(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::IMul: + vcResult = builder.CreateMul(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::FMul: + vcResult = builder.CreateFMul(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::FDiv: + vcResult = builder.CreateFDiv(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::SDiv: + vcResult = builder.CreateSDiv(lhs, rhs); + break; + case Builder::CooperativeMatrixArithOp::UDiv: + vcResult = builder.CreateUDiv(lhs, rhs); + break; + default: + llvm_unreachable("unsupported binary operation for cooprative matrix!"); // Rem/Mod is not supported currently. + } + + Value *coopMatResult = convFlatVecToCoopMatrixVec(builder, vcResult, elemType, layout); + return coopMatResult; +} + +// ===================================================================================================================== +// Create cooperative matrix MatrixTimesScalar operation +// +// @param matrix : The first operand and it should be a cooperative matrix. +// @param scalar : The second operand and it should be a scalar. +// @param elemType : The component type of the matrix. +// @param layout : Identify whether it's A/B or C/D +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::coopMatrixTimesScalar(Value *matrix, Value *scalar, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout layout, const Twine &instName, + Instruction *insertPos) { + assert(matrix->getType()->getScalarType()->isIntegerTy() || matrix->getType()->getScalarType()->isFloatTy()); + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + Value *vcFlat = convCoopMatrixVecToFlatVec(builder, matrix, elemType, layout); + const unsigned numElems = cast(vcFlat->getType())->getNumElements(); + auto splat = builder.CreateVectorSplat(numElems, scalar); + Value *vcFlatResult; + if ((elemType == Builder::CooperativeMatrixElementType::Float16) || + (elemType == Builder::CooperativeMatrixElementType::Float32)) { + vcFlatResult = builder.CreateFMul(vcFlat, splat); + } else { + vcFlatResult = builder.CreateMul(vcFlat, splat); + } + Value *coopMatResult = convFlatVecToCoopMatrixVec(builder, vcFlatResult, elemType, layout); + return coopMatResult; +} + +// ===================================================================================================================== +// Create cooperative matrix reshape operation only for the element is float16 +// +// @param source : The first operand and it should be a cooperative matrix. +// @param srcElemType : The component type of the matrix. +// @param srcLayout : Identify whether it's A/B or C/D +// @param dstLayout : Identify whether it's A/B or C/D +// @param castOp : Identify which cast operation is used +// @param threadId : Identify which lane +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixReshape16BitElementGfx1011( + Value *source, Builder::CooperativeMatrixElementType srcElemType, Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, Value *threadId, const Twine &instName, Instruction *insertPos) { + assert(srcElemType == Builder::CooperativeMatrixElementType::Float16 || + srcElemType == Builder::CooperativeMatrixElementType::Int16); + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = nullptr; + auto waveSize = m_pipelineState->getShaderWaveSize(m_shaderStage); + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *isEvenGroup = builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + + auto mapFuncX16 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic( + int32Ty, Intrinsic::amdgcn_permlanex16, + {mappedArgs[0], mappedArgs[1], passthroughArgs[0], passthroughArgs[1], passthroughArgs[2], passthroughArgs[3]}); + }; + auto mapFunc64 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic(int32Ty, Intrinsic::amdgcn_permlane64, {mappedArgs[0]}); + }; + if (srcLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { // From A/B to C/D for 16bit element + Type *packedTy = + (srcElemType == Builder::CooperativeMatrixElementType::Float16) ? builder.getFloatTy() : builder.getInt32Ty(); + if (dstLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) { + unsigned vecSize = cast(source->getType())->getNumElements(); + assert(vecSize == 8); // A/B should be 8*float16 or 8*int16 + unsigned shiftVecNums = 8; + // wave32/wave64: lane0: {1_0:0_0 3_0:2_0....15_0:14_0} lane15:{1_15:0_15 3_15:2_15...15_15:14_15}/lane16~lane31 + // is redundant reshape to wave32: lane0:{0_0 2_0 4_0....14_0} lane16:{1_0 3_0 5_0...15_0} wave64: lane0:{0_0 4_0 + // 8_0 12_0} lane16:{1_0 5_0 9_0 13_0} lane32:{2_0 6_0 10_0 14_0} lane48:... + resultValue = builder.CreateBitCast(source, FixedVectorType::get(builder.getInt32Ty(), vecSize)); + if (waveSize == 64) { + resultValue = PoisonValue::get(FixedVectorType::get(packedTy, 4)); + for (unsigned idx = 0; idx < vecSize; idx += 2) { + Value *low = builder.CreateExtractElement(source, idx); + Value *high = builder.CreateExtractElement(source, idx + 1); + Value *select = builder.CreateSelect( // Select between lane0-31 and lane32-63 + builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(2)), builder.getInt32(0)), low, + high); + // Lane0: {1_0:0_0 5_0:4_0...} lane16=lane0 lane32: {3_0:2_0 7_0:6_0....} lane48=lane32 + resultValue = builder.CreateInsertElement(resultValue, select, idx / 2, instName); + } + resultValue = builder.CreateBitCast( + resultValue, + FixedVectorType::get(builder.getInt32Ty(), 4)); // Convert to 4*int32 for shl or and/or operation + shiftVecNums = 4; + } + Value *shiftZeorValue = builder.CreateVectorSplat(shiftVecNums, builder.getInt32(0)); + Value *shift16Value = builder.CreateVectorSplat(shiftVecNums, builder.getInt32(16)); + + // Wave32: lane0: {1_0:0_0 3_0:2_0....15_0:14_0} lane16: {1_0:0_0 3_0:2_0....15_0:14_0} => + // lane0: {1_0:0_0 3_0:2_0....15_0:14_0} lane16: {unused:1_0 unused:3_0....unused:15_0} + // wave64: lane0: {1_0:0_0 5_0:4_0...} lane16=lane0 lane32:{3_0:2_0 7_0:6_0....} lane48=lane32 => + // lane0: {1_0:0_0 5_0:4_0....13_0:12_0} lane16: {unused:1_0 unused:5_0....unused:13_0} lane32:{3_0:2_0 + // 7_0:6_0....} lane48: {unused:3_0 unused:7_0....} + + // 1.Bitcast matrix to + // 2.Shift right by laneGroupIndex ? 16 : 0 (you can probably do CreateLShr + // 3.Bitcast to if necessary:This will leave garbage in the upper 16 bits of some of the lanes, + // but I don't think that's a problem. + + resultValue = + builder.CreateLShr(resultValue, builder.CreateSelect(isEvenGroup, shiftZeorValue, shift16Value), instName); + if (srcElemType == Builder::CooperativeMatrixElementType::Float16) { + resultValue = builder.CreateBitCast(resultValue, FixedVectorType::get(builder.getFloatTy(), shiftVecNums), + instName); // Bitcast to 8*bit32 for wave32 and 4*bit32 for wave64 + resultValue = builder.CreateShuffleVector(resultValue, PoisonValue::get(resultValue->getType()), + {0, 1, 2, 3, 4, 5, 6, 7}); + } + } else if (dstLayout == + Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout) { // Emulation on NAVI2X + // from A/B to C/D on 16bit + resultValue = PoisonValue::get(FixedVectorType::get(packedTy, 8)); + // Wave32/wave64 : lane0 : {1_0:0_0 3_0:2_0....15_0:14_0} lane16 : {1_0:0_0 3_0:2_0....15_0:14_0} + // lane16 ~lane31 is redundant reshape to + // Wave32/wave64 : lane0 : {1_0:0_0 5_0:4_0....13_0:12_0} lane16 : {3_0:2_0 7_0:6_0...15_0:14_0} + source = builder.CreateBitCast(source, FixedVectorType::get(packedTy, 8)); + Value *isEvenGroup = + builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + for (unsigned idx = 0; idx < 8; idx += 2) { + Value *lowSubValue = builder.CreateExtractElement(source, idx); + Value *highSubValue = builder.CreateExtractElement(source, idx + 1); + Value *select = builder.CreateSelect(isEvenGroup, lowSubValue, highSubValue); + resultValue = builder.CreateInsertElement(resultValue, select, idx / 2, instName); + } + } else { + // It's unnecessary for reshape after gfx11. + resultValue = source; + } + } else if (srcLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) { + if (dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + // lane0----lan16----lane32-----lane48*/ + // 1x-------1y-------1m---------1n*/ + // ==> */ + // {1y,1x}---{1y,1x}--{1n,1m}----{1n,1m},*/ + + // Source now is 8*half not care wave32 or wave64 + // Zext to 8*int@wave64, the upper 16bits will not be used. + // Permulate lane and composite the elements showns as above. + // There will be two cases when change accumulator layout(32bit) to factor layout(16bit): + // 1. Convert on the element: float32(fptrunc)->float16 + // 2. Reshape after MulAdd(float16*float16+float16)->Need change C/D layout to A/B layout + // So it needs using castOp to identify which case happened. + unsigned vecNums = 8; + Value *matrix = + builder.CreateShuffleVector(source, PoisonValue::get(source->getType()), {0, 1, 2, 3, 4, 5, 6, 7}); + + static const unsigned LaneSelBits[2] = {0x76543210, 0xfedcba98}; + Value *swapped = builder.CreateMapToSimpleType( + mapFuncX16, + { + matrix, + matrix, + }, + {builder.getInt32(LaneSelBits[0]), builder.getInt32(LaneSelBits[1]), builder.getFalse(), builder.getFalse()}); + + Value *first = builder.CreateSelect(isEvenGroup, matrix, swapped); + Value *second = builder.CreateSelect(isEvenGroup, swapped, matrix); + + Value *shiftValue = builder.CreateVectorSplat(vecNums, builder.getInt32(16)); + Value *maskValue = builder.CreateVectorSplat(vecNums, builder.getInt32(0xffff)); + + Value *maskedFirst = builder.CreateAnd(first, maskValue); + matrix = builder.CreateOr(maskedFirst, builder.CreateShl(second, shiftValue)); + + // For wave64: step1: merge lane0+lane32 lane16+lane48 + // Each lane value: float/int32 * 4+ poison value*4 + // lane0:{1_0:0_0 5_0:4_0...} lane16:{1_0:0_0 5_0:4_0...} lane32:{3_0:2_0 7_0:6_0...} lane48{3_0:2_0 7_0:6_0...} + // --shuffle--> lane0: {1_0:0_0 3_0:2_0 5_0:4_0....} lane16:{1_0:0_0 3_0:2_0 5_0:4_0....} lane32: {1_0:0_0 + // 3_0:2_0 5_0:4_0....} lane48:{1_0:0_0 3_0:2_0 5_0:4_0....} For wave32: lane0: {1_0:0_0 3_0:2_0 5_0:4_0....} + // lane16:{1_0:0_0 3_0:2_0 5_0:4_0....} lane32=lane0 lanes48=lane16 + + if (waveSize == 64) { + Value *swapped = builder.CreateMapToSimpleType(mapFunc64, matrix, {}); + Value *const laneIdLessThan32 = builder.CreateICmpULT(threadId, builder.getInt32(32)); + Value *first = builder.CreateSelect(laneIdLessThan32, matrix, swapped); + Value *second = builder.CreateSelect(laneIdLessThan32, swapped, matrix); + matrix = builder.CreateShuffleVector(first, second, ArrayRef({0, 8, 1, 9, 2, 10, 3, 11}), instName); + } + // After shuffle wave64's layout is same with wave32 + if (srcElemType == Builder::CooperativeMatrixElementType::Float16) { + matrix = builder.CreateBitCast(matrix, FixedVectorType::get(builder.getFloatTy(), 8)); //->8*f32 + } + resultValue = matrix; + } + } else if (srcLayout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout) { + if (dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { // NAVI2X:16bit reshape C/D->A/B + // C/D: LANE0: {1_0:0_0 5_0:4_0 9_0:8_0 13_0:12_0} LANE16:{3_0:2_0 7_0:6_0 11_0:10_0 15_0:14_0}===> + // A/B: LANE0: {1_0:0_0 3_0:2_0 5_0:4:0....15_0:14_0} LANE16=LANE0 + Type *packedTy = + (srcElemType == Builder::CooperativeMatrixElementType::Float16) ? builder.getFloatTy() : builder.getInt32Ty(); + resultValue = PoisonValue::get(FixedVectorType::get(packedTy, 8)); + unsigned LaneSelBits[2] = {0x76543210, 0xfedcba98}; + Value *swapped = builder.CreateMapToSimpleType( + mapFuncX16, + { + source, + source, + }, + {builder.getInt32(LaneSelBits[0]), builder.getInt32(LaneSelBits[1]), builder.getFalse(), builder.getFalse()}); + + Value *first = builder.CreateSelect(isEvenGroup, source, swapped); + Value *second = builder.CreateSelect(isEvenGroup, swapped, source); + for (unsigned idx = 0; idx < 8; idx += 2) { // A/B will always be V8 + Value *firstValue = builder.CreateExtractElement(first, idx / 2); + Value *secondValue = builder.CreateExtractElement(second, idx / 2); + resultValue = builder.CreateInsertElement(resultValue, firstValue, idx, instName); + resultValue = builder.CreateInsertElement(resultValue, secondValue, idx + 1, instName); + } + } + } else { + llvm_unreachable("The layout is not supported."); + } + return resultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix reshape operation only for the element is int8 +// +// @param source : The first operand and it should be a cooperative matrix. +// @param srcElemType : The component type of the matrix. +// @param srcLayout : Identify whether it's A/B or C/D +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixReshapeBetween8bitAnd32bitElementGfx1011( + Value *source, Builder::CooperativeMatrixElementType srcElemType, Builder::CooperativeMatrixLayout srcLayout, + const Twine &instName, Instruction *insertPos) { + + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = nullptr; + auto waveSize = m_pipelineState->getShaderWaveSize(m_shaderStage); + Value *threadId = getLaneNumber(builder); + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *isEvenGroup = builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + + if (srcLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + assert(srcElemType == Builder::CooperativeMatrixElementType::Int8); + Value *int8Value = builder.CreateBitCast(source, FixedVectorType::get(builder.getInt8Ty(), 16)); + if ((waveSize == 32) || (m_gfxIp.major < 11)) { + Value *lowValue = builder.CreateShuffleVector(int8Value, ArrayRef({0, 2, 4, 6, 8, 10, 12, 14})); + Value *highValue = builder.CreateShuffleVector(int8Value, ArrayRef({1, 3, 5, 7, 9, 11, 13, 15})); + resultValue = builder.CreateSelect(isEvenGroup, lowValue, highValue, instName); + } else { + Value *lowlowValue = builder.CreateShuffleVector(int8Value, ArrayRef({0, 4, 8, 12})); + Value *lowhighValue = builder.CreateShuffleVector(int8Value, ArrayRef({1, 5, 9, 13})); + Value *highlowValue = builder.CreateShuffleVector(int8Value, ArrayRef({2, 6, 10, 14})); + Value *highhighValue = builder.CreateShuffleVector(int8Value, ArrayRef({3, 7, 11, 15})); + + Value *const laneIdLessThan32 = builder.CreateICmpULT(threadId, builder.getInt32(32)); + Value *isEvenGroupLessThan32 = builder.CreateAnd(laneIdLessThan32, isEvenGroup); + Value *isOddGroupLessThan32 = builder.CreateAnd(laneIdLessThan32, builder.CreateNot(isEvenGroup)); + Value *isEvenGroupMoreThan32 = builder.CreateAnd(builder.CreateNot(laneIdLessThan32), isEvenGroup); + Value *isOddGroupMoreThan32 = + builder.CreateAnd(builder.CreateNot(laneIdLessThan32), builder.CreateNot(isEvenGroup)); + + resultValue = lowlowValue; + resultValue = builder.CreateSelect(isEvenGroupLessThan32, lowlowValue, resultValue, instName); + resultValue = builder.CreateSelect(isOddGroupLessThan32, lowhighValue, resultValue, instName); + resultValue = builder.CreateSelect(isEvenGroupMoreThan32, highlowValue, resultValue, instName); + resultValue = builder.CreateSelect(isOddGroupMoreThan32, highhighValue, resultValue, instName); + } + } else if (srcLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout || + srcLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout) { + // + assert(srcElemType == Builder::CooperativeMatrixElementType::Int32 || + srcElemType == Builder::CooperativeMatrixElementType::Float32); + // unsigned vecSize = cast(source->getType())->getNumElements(); + unsigned vecSize = 8; + source = builder.CreateShuffleVector(source, PoisonValue::get(source->getType()), {0, 1, 2, 3, 4, 5, 6, 7}); + unsigned LaneSelBits[2] = {0x76543210, 0xfedcba98}; + auto mapFuncX16 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic(int32Ty, Intrinsic::amdgcn_permlanex16, + {mappedArgs[0], mappedArgs[1], passthroughArgs[0], passthroughArgs[1], + passthroughArgs[2], passthroughArgs[3]}); + }; + + Value *swapped = builder.CreateMapToSimpleType( + mapFuncX16, + { + source, + source, + }, + {builder.getInt32(LaneSelBits[0]), builder.getInt32(LaneSelBits[1]), builder.getFalse(), builder.getFalse()}); + + Value *first = builder.CreateSelect(isEvenGroup, source, swapped); + Value *second = builder.CreateSelect(isEvenGroup, swapped, source); + Value *afterPermValue = PoisonValue::get(FixedVectorType::get(builder.getInt8Ty(), vecSize * 2)); + for (unsigned idx = 0; idx < vecSize * 2; idx += 2) { + Value *firstElement = builder.CreateExtractElement(first, idx / 2); + Value *secondElement = builder.CreateExtractElement(second, idx / 2); + afterPermValue = builder.CreateInsertElement(afterPermValue, firstElement, idx, "firstElement"); + afterPermValue = builder.CreateInsertElement(afterPermValue, secondElement, idx + 1, "secondElement"); + } + afterPermValue = builder.CreateBitCast(afterPermValue, FixedVectorType::get(builder.getInt16Ty(), vecSize)); + + if ((m_gfxIp.major == 11) && (waveSize == 64)) { + auto mapFunc64 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic(int32Ty, Intrinsic::amdgcn_permlane64, {mappedArgs[0]}); + }; + Value *swapped = builder.CreateMapToSimpleType(mapFunc64, afterPermValue, {}); + + Value *const laneIdLessThan32 = builder.CreateICmpULT(threadId, builder.getInt32(32)); + Value *first = builder.CreateSelect(laneIdLessThan32, afterPermValue, swapped); + Value *second = builder.CreateSelect(laneIdLessThan32, swapped, afterPermValue); + afterPermValue = builder.CreateShuffleVector(first, second, ArrayRef({0, 8, 1, 9, 2, 10, 3, 11})); // 8*int16 + } + // bitcast: lane0 : {1_0:0_0 3_0:2_0... }(8 * int16) lane16 : {1_0 : 0_0 3_0 : 2_0...}(8 * int16) to + // lane0 : {3_0:2_0:1_0:0_0...}(4*int32) */ + resultValue = + builder.CreateBitCast(afterPermValue, FixedVectorType::get(builder.getInt32Ty(), 4), "Int16V8ToInt32V4"); + } else { + llvm_unreachable("The layout is not supported."); + } + return resultValue; +} + +// ===================================================================================================================== +// Change the 16bit layout for fconvert from f16(f32) to f32(f16) +// +// @param source : The first operand and it should be a cooperative matrix. +// @param srcLayout : Identify whether it's A/B or C/D +// @param dstLayout : Identify whether it's A/B or C/D +// @param isEvenGroup : Identify which row +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10( + Value *source, Builder::CooperativeMatrixElementType srcElemType, Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout layout, Value *isEvenGroup, const Twine &instName, Instruction *insertPos) { + // 1. After convert from f32->f16: change the layout from 32bit layout to 16bit layout on Accumulator on gfx10. + // 2. Before convert from f16->f32: change the layout from 16bit layout to 32bit layout on Accumulator on gfx10 + + // For 1st case: lane0:{0_0 2_0 4_0..14_0} lane16:{1_0 3_0 5_0...15_0} lane32=lane0 lane48=lane16(8*half) ==> + // lane0:{1_0:0_0 5_0:4_0 ....} lane16:{3_0:2_0 7_0:6_0..} (4*float) + // For 2nd case: lane0:{1_0:0_0 5_0:4_0 ....} lane16:{3_0:2_0 7_0:6_0..}(4*float) ==> + // lane0:{0_0 2_0 4_0..14_0} lane16:{1_0 3_0 5_0...15_0}(8*half) + // From the implementation side, it's same which only exchange off-diaglog element between {2_0:0_0} and {3_0:1_0}(1st + // case) + // or {1_0:0_0} and {3_0:2_0}(2nd case) + assert(layout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout || + layout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout); + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + Value *resultValue = nullptr; + if (dstElemType == Builder::CooperativeMatrixElementType::Float16 || + dstElemType == Builder::CooperativeMatrixElementType::Int16) { + source = builder.CreateBitCast(source, FixedVectorType::get(builder.getInt32Ty(), 4)); + } else if (dstElemType == Builder::CooperativeMatrixElementType::Float32 || + dstElemType == Builder::CooperativeMatrixElementType::Int32) { + source = builder.CreateBitCast(source, FixedVectorType::get(builder.getInt32Ty(), 8)); + } + unsigned LaneSelBits[2] = {0x76543210, 0xfedcba98}; + auto mapFuncX16 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic( + int32Ty, Intrinsic::amdgcn_permlanex16, + {mappedArgs[0], mappedArgs[1], passthroughArgs[0], passthroughArgs[1], passthroughArgs[2], passthroughArgs[3]}); + }; + Value *matrix = source; + Value *swapped = builder.CreateMapToSimpleType( + mapFuncX16, + { + matrix, + matrix, + }, + {builder.getInt32(LaneSelBits[0]), builder.getInt32(LaneSelBits[1]), builder.getFalse(), builder.getFalse()}); + + unsigned shiftVecNums = cast(swapped->getType())->getNumElements(); + Value *maskLowValue = builder.CreateVectorSplat(shiftVecNums, builder.getInt32(0x0000ffff)); + Value *maskHighValue = builder.CreateVectorSplat(shiftVecNums, builder.getInt32(0xffff0000)); + Value *shiftValue = builder.CreateVectorSplat(shiftVecNums, builder.getInt32(16)); + + Value *maskedSourceLow = builder.CreateAnd(source, maskLowValue); + Value *lowVal = builder.CreateSelect(isEvenGroup, maskedSourceLow, + builder.CreateAnd(builder.CreateLShr(swapped, shiftValue), maskLowValue)); + + Value *maskedSourceHigh = builder.CreateAnd(source, maskHighValue); + Value *highVal = builder.CreateSelect( + isEvenGroup, builder.CreateAnd(builder.CreateShl(swapped, shiftValue), maskHighValue), maskedSourceHigh); + resultValue = builder.CreateOr(highVal, lowVal); + + if (srcElemType == Builder::CooperativeMatrixElementType::Float16 && + (dstElemType == Builder::CooperativeMatrixElementType::Float32 || + dstElemType == Builder::CooperativeMatrixElementType::Int32)) { + resultValue = + builder.CreateBitCast(resultValue, FixedVectorType::get(builder.getHalfTy(), 16)); // 2nd case:before convert + } else { + resultValue = + builder.CreateBitCast(resultValue, FixedVectorType::get(builder.getFloatTy(), 4)); // 1st case:after convert + } + return resultValue; +} + +// ===================================================================================================================== +// Adjust the layout before reshape operation from small size type into large size type(eg:float16->float32) +// +// @param source : The first operand and it should be a cooperative matrix. +// @param srcLayout : Identify whether it's A/B or C/D +// @param dstLayout : Identify whether it's A/B or C/D +// @param srcElemType : The source component type of the matrix. +// @param dstElemType : The destination component type of the matrix. +// @param isEvenGroup : Identify which row +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixReshapeBeforeConvert(Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + const Twine &instName, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = source; + + Value *threadId = getLaneNumber(builder); + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *isEvenGroup = builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + + if (srcElemType == Builder::CooperativeMatrixElementType::Float16 || + srcElemType == Builder::CooperativeMatrixElementType::Int16) { + if (srcLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout) { + resultValue = cooperativeMatrixReshape16BitElementGfx1011(source, srcElemType, srcLayout, dstLayout, threadId, + "reshapeFactorToAcc", insertPos); + resultValue = convCoopMatrixVecToFlatVec(builder, resultValue, srcElemType, dstLayout); + } else if (srcLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout) { + resultValue = cooperativeMatrixReshape16BitElementGfx1011( + source, srcElemType, srcLayout, Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout, threadId, + "reshapeFactorToAcc", insertPos); + resultValue = cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10( + resultValue, srcElemType, dstElemType, dstLayout, isEvenGroup, "beforef16tof32", insertPos); + resultValue = convCoopMatrixVecToFlatVec(builder, resultValue, srcElemType, dstLayout); + } else if (srcLayout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout) { + resultValue = cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10(source, srcElemType, dstElemType, dstLayout, + isEvenGroup, "beforef16tof32", insertPos); + resultValue = convCoopMatrixVecToFlatVec(builder, resultValue, srcElemType, dstLayout); + } else { + llvm_unreachable("Unsupported layout!"); + } + } else if (srcElemType == Builder::CooperativeMatrixElementType::Int8) { + // 8bit already return the N*flatType, it's unnecessary to call convCoopMatrixVecToFlatVec + if (srcLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + resultValue = cooperativeMatrixReshapeBetween8bitAnd32bitElementGfx1011(source, srcElemType, srcLayout, + "reshapeFactorToAcc", insertPos); + } else { + // 8bit->32bit, no reshape is necessary as all elements are sorted consistently between 8bitLayout and + // 32bitLayout. + resultValue = convCoopMatrixVecToFlatVec(builder, resultValue, srcElemType, srcLayout); + } + } else { + resultValue = convCoopMatrixVecToFlatVec(builder, resultValue, srcElemType, srcLayout); + } + return resultValue; +} + +// ===================================================================================================================== +// Adjust the layout after reshape operation from large size type into small size type(eg:float32->float16) +// +// @param source : The first operand and it should be a cooperative matrix. +// @param srcLayout : Identify whether it's A/B or C/D +// @param dstLayout : Identify whether it's A/B or C/D +// @param srcElemType : The source component type of the matrix. +// @param dstElemType : The destination component type of the matrix. +// @param isEvenGroup : Identify which row +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixReshapeAfterConvert(Value *source, + Builder::CooperativeMatrixElementType srcElemType, + Builder::CooperativeMatrixElementType dstElemType, + Builder::CooperativeMatrixLayout srcLayout, + Builder::CooperativeMatrixLayout dstLayout, + const Twine &instName, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Value *resultValue = source; + + Value *threadId = getLaneNumber(builder); + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *isEvenGroup = builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + + if (dstElemType == Builder::CooperativeMatrixElementType::Float16 || + dstElemType == Builder::CooperativeMatrixElementType::Int16) { + if (srcLayout == Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + // It needs to convert 16bit*8 into 32bit*8(high 16bit will be unused) as + // the input for reshape interface will be 32bit*8 keeping compatibility for reshape+muladd+reshape case. + resultValue = + builder.CreateShuffleVector(resultValue, PoisonValue::get(source->getType()), {0, 1, 2, 3, 4, 5, 6, 7}); + resultValue = builder.CreateBitCast(resultValue, FixedVectorType::get(builder.getInt16Ty(), 8)); + resultValue = builder.CreateZExt(resultValue, FixedVectorType::get(builder.getInt32Ty(), 8), "zext"); + resultValue = cooperativeMatrixReshape16BitElementGfx1011(resultValue, dstElemType, srcLayout, dstLayout, + threadId, "reshapeAccToFactor", insertPos); + } else if (srcLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { + resultValue = cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10(source, srcElemType, dstElemType, srcLayout, + isEvenGroup, "afterf32tof16", insertPos); + resultValue = cooperativeMatrixReshape16BitElementGfx1011( + resultValue, dstElemType, Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout, dstLayout, + threadId, "reshapeAccToFactor", insertPos); + } else if (srcLayout == Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout && + dstLayout == Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout) { + resultValue = cooperativeMatrixReshapeBetween16bitAnd32bitOnAccGfx10(source, srcElemType, dstElemType, srcLayout, + isEvenGroup, "afterf32tof16", insertPos); + } else { + llvm_unreachable("Unsupported elemtype!"); + } + } else if (dstElemType == Builder::CooperativeMatrixElementType::Int8) { + if (dstLayout == Builder::CooperativeMatrixLayout::FactorMatrixLayout) { // gfx10/gfx11: 32bit->8bit + resultValue = cooperativeMatrixReshapeBetween8bitAnd32bitElementGfx1011(source, srcElemType, srcLayout, + "reshapeFactorToAcc", insertPos); + } else { + // 32bit->8bit, no reshape is necessary as all elements are sorted consistently between 8bitLayout and + // 32bitLayout. + resultValue = convFlatVecToCoopMatrixVec(builder, resultValue, dstElemType, dstLayout); + } + } + return resultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix transpose operation +// +// @param matrix : The first operand and it should be a cooperative matrix. +// @param elemType : The component type of the matrix. +// @param srcLayout: Identify whether it's A/B or C/D +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixTranspose(llvm::Value *matrix, + Builder::CooperativeMatrixElementType elemType, + Builder::CooperativeMatrixLayout srcLayout, + const Twine &instName, llvm::Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + Value *threadId = getLaneNumber(builder); + Value *isEvenThread = builder.CreateICmpEQ(builder.CreateAnd(threadId, builder.getInt32(1)), builder.getInt32(0)); + unsigned vecSize = cast(matrix->getType())->getNumElements(); + unsigned vecStride, laneStride; + + auto mapFuncDpp8 = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + return builder.CreateIntrinsic(Intrinsic::amdgcn_mov_dpp8, builder.getInt32Ty(), + {mappedArgs[0], passthroughArgs[0]}); + }; + + auto mapFuncPerm = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic(int32Ty, Intrinsic::amdgcn_perm, {mappedArgs[0], mappedArgs[1], passthroughArgs[0]}); + }; + + Value *dpp8 = builder.getInt32(1 | 0 << 3 | 3 << 6 | 2 << 9 | 5 << 12 | 4 << 15 | 7 << 18 | 6 << 21); + Value *matrixShuffle = builder.CreateMapToSimpleType(mapFuncDpp8, matrix, {dpp8}); + + if (elemType == Builder::CooperativeMatrixElementType::Int8) { + // 1st step: {3_0:2_0:1_0:0_0} {3_1:2_1:1_1:0_1} -> + // {0_1:0_0:2_1:2_0} {1_1:1_0:3_1:3_0} + + Value *transValueEven = + builder.CreateMapToSimpleType(mapFuncPerm, {matrixShuffle, matrix}, builder.getInt32(0x04000602)); + Value *transValueOdd = + builder.CreateMapToSimpleType(mapFuncPerm, {matrixShuffle, matrix}, builder.getInt32(0x01050307)); + Value *transValue = builder.CreateSelect(isEvenThread, transValueEven, transValueOdd); + + // 2nd step + // 0_1:0_0:2_1:2_0----1_1:1_0:3_1:3_0-----0_3:0_2:2_3:2_2----1_3:1_2:3_3:3_2 ==> + // 0_3:0_2:0_1:0_0 1_3:1_2:1_1:1_0 2_3:2_2:2_1:2_0 3_3:3_2:3_1:3_0 + + dpp8 = builder.getInt32(2 | 3 << 3 | 0 << 6 | 1 << 9 | 6 << 12 | 7 << 15 | 4 << 18 | 5 << 21); + + Value *transValueShuffle = builder.CreateMapToSimpleType(mapFuncDpp8, transValue, {dpp8}); + Value *srclowlane = builder.CreateICmpEQ(builder.CreateAnd(threadId, builder.getInt32(2)), builder.getInt32(0)); + + Value *matrixSlow = + builder.CreateMapToSimpleType(mapFuncPerm, {transValueShuffle, transValue}, builder.getInt32(0x07060302)); + Value *matrixHigh = + builder.CreateMapToSimpleType(mapFuncPerm, {transValueShuffle, transValue}, builder.getInt32(0x01000504)); + matrix = builder.CreateSelect(srclowlane, matrixSlow, matrixHigh); + + vecStride = 1; + laneStride = 4; + } else if (elemType == Builder::CooperativeMatrixElementType::Int16 || + elemType == Builder::CooperativeMatrixElementType::Float16) { + // lane0:{1_0, 0_0} lane1:{1_1,0_1} -> lane0: {0_1, 0_0} lane1:{1_1, 1_0} + matrix = builder.CreateBitCast(matrix, FixedVectorType::get(builder.getInt32Ty(), vecSize)); + matrixShuffle = builder.CreateBitCast(matrixShuffle, FixedVectorType::get(builder.getInt32Ty(), vecSize)); + Value *shiftValue = builder.CreateVectorSplat(vecSize, builder.getInt32(16)); + Value *highmaskValue = builder.CreateVectorSplat(vecSize, builder.getInt32(0xFFFF0000)); + Value *lowmaskValue = builder.CreateVectorSplat(vecSize, builder.getInt32(0x0000FFFF)); + + Value *maskedMatrixHigh = builder.CreateAnd(matrix, highmaskValue); + Value *high = builder.CreateSelect(isEvenThread, builder.CreateShl(matrixShuffle, shiftValue), maskedMatrixHigh); + Value *maskedMatrixLow = builder.CreateAnd(matrix, lowmaskValue); + Value *low = builder.CreateSelect(isEvenThread, maskedMatrixLow, builder.CreateLShr(matrixShuffle, shiftValue)); + matrix = builder.CreateOr(high, low); + if (elemType == Builder::CooperativeMatrixElementType::Float16) { + matrix = builder.CreateBitCast(matrix, FixedVectorType::get(builder.getFloatTy(), vecSize)); + } + vecStride = 1; + laneStride = 2; + } else { + llvm_unreachable("Element type is not supported."); + } + + // lane0/V0: {0_0,0_1}; V1: {2_0,2_1} lane2/V0:{0_2,0_3} V1:{2_2,2_3} ==> + // lane0/V0: {0_0,0_1}; V1: {0_2,0_3} lane2/V0:{2_0,2_1} V1:{2_2,2_3} + Value *resultValue = transposeCooperativeMatrixRecursively(matrix, vecStride, laneStride, threadId, builder); + return resultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix transpose +// @param matrix : The first operand and it should be a cooperative matrix. +// @param vecStride : Identify stride in element vector when transpose block size +// @param laneStride : Identify stride in lane when transpose block size +// @param threadId : Current threadId. +// @param builder : The IR builder to create and insert IR instruction +Value *LowerCooperativeMatrix::transposeCooperativeMatrixRecursively(llvm::Value *matrix, unsigned vecStride, + unsigned laneStride, Value *threadId, + BuilderBase &builder) { + unsigned vgprNums = cast(matrix->getType())->getNumElements(); + if (vecStride >= vgprNums) { + return matrix; + } + + DppCtrl dppCtrl = DppCtrl::DppQuadPerm0000; + switch (laneStride) { + case 1: + dppCtrl = DppCtrl::DppRowXmask1; + break; + case 2: + dppCtrl = DppCtrl::DppRowXmask2; + break; + case 4: + dppCtrl = DppCtrl::DppRowXmask4; + break; + case 8: + dppCtrl = DppCtrl::DppRowXmask8; + break; + default: + llvm_unreachable("The stride is not correct!"); + } + + auto mapFuncdppmove = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + return builder.CreateIntrinsic( + Intrinsic::amdgcn_mov_dpp, builder.getInt32Ty(), + {mappedArgs[0], passthroughArgs[0], passthroughArgs[1], passthroughArgs[2], passthroughArgs[3]}); + }; + + Value *transResultValue = PoisonValue::get(matrix->getType()); + Value *replaceLaneValue = + builder.CreateMapToSimpleType(mapFuncdppmove, matrix, + {builder.getInt32(static_cast(dppCtrl)), builder.getInt32(15), + builder.getInt32(15), builder.getTrue()}); + + Value *swapFlag = builder.CreateICmpNE(builder.CreateAnd(threadId, laneStride), builder.getInt32(0)); + Value *inverseSwapFlag = builder.CreateNot(swapFlag); + + for (int index = 0; index < vgprNums; ++index) { + Value *srcValue = builder.CreateExtractElement(matrix, index); + unsigned replaceValueIndex = index ^ vecStride; + Value *replaceValue = builder.CreateExtractElement(replaceLaneValue, replaceValueIndex); + // if (i & n) == 0: + // dst = swapFlag ? tmp : src + // else: + // dst = inverseSwapFlag ? tmp : src + Value *dst; + if ((index & vecStride) == 0) { + dst = builder.CreateSelect(swapFlag, replaceValue, srcValue); + } else { + dst = builder.CreateSelect(inverseSwapFlag, replaceValue, srcValue); + } + transResultValue = builder.CreateInsertElement(transResultValue, dst, index); + } + + vecStride = vecStride << 1; + laneStride = laneStride << 1; + + transResultValue = transposeCooperativeMatrixRecursively(transResultValue, vecStride, laneStride, threadId, builder); + return transResultValue; +} + +// ===================================================================================================================== +// Create cooperative matrix muladd operation +// +// @param matrixA : Factor cooperative matrix. +// @param matrixB : Factor cooperative matrix. +// @param matrixC : Accumulator cooperative matrix. +// @param isSignedA : Identify the signess for matrix A's element type +// @param isSignedB : Identify the signess for matrix B's element type +// @param isSat : SaturatingAccumulation for calculation +// @param accumElemType : The component type of the accumulator matrix. +// @param factorElemType : The component type of the factor matrix. +// @param matrixCLayout: The layout for matrix C/D. +// @param instName : Name to give instruction(s). +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::cooperativeMatrixMulAdd(llvm::Value *matrixA, llvm::Value *matrixB, llvm::Value *matrixC, + bool isSignedA, bool isSignedB, bool isSat, + Builder::CooperativeMatrixElementType accumElemType, + Builder::CooperativeMatrixElementType factorElemType, + const Twine &instName, Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + if (m_gfxIp.major >= 11) { + // Gfx11: + // wave64: + // declare <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half>, <16 x half>, <4 x float>) + // declare <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16(<16 x i16>, <16 x i16>, <4 x float>) + // declare <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half>, <16 x half>, <8 x half>, i1 immarg) + // declare <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16>, <16 x i16>, <8 x i16>, i1 immarg) + // declare <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8(i1 immarg, <4 x i32>, i1 immarg, <4 x i32>, <4 x i32>, i1 + // immarg) declare <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4(i1 immarg, <2 x i32>, i1 immarg, <2 x i32>, <4 x + // i32>, i1 immarg) + // wave32: + // declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half>, <16 x half> , <8 x float>) + // declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16(<16 x i16>, <16 x i16> , <8 x float>) + // declare <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half>, <16 x half> , <16 x half>, i1 immarg) + // declare <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16>, <16 x i16> , <16 x i16>, i1 immarg) + // declare <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8(i1 immarg, <4 x i32>, i1 immarg, <4 x i32> , <8 x i32>, i1 + // immarg) declare <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4(i1 immarg, <2 x i32>, i1 immarg, <2 x i32> , <8 x + // i32>, i1 immarg) + Value *matrixD; + unsigned waveSize = m_pipelineState->getShaderWaveSize(m_shaderStage); + + if (factorElemType == Builder::CooperativeMatrixElementType::Float16 || + factorElemType == Builder::CooperativeMatrixElementType::Int16) { + unsigned factorFlatElemNum = 0; + { factorFlatElemNum = 16; } + Type *factorType = + FixedVectorType::get(builder.transCooperativeMatrixElementType(factorElemType), factorFlatElemNum); + matrixA = builder.CreateBitCast(matrixA, factorType); + matrixB = builder.CreateBitCast(matrixB, factorType); + } else if (factorElemType == Builder::CooperativeMatrixElementType::Int8) { + } else { + llvm_unreachable("Factor element type is not supported!"); + } + + if (accumElemType == Builder::CooperativeMatrixElementType::Float32 || + accumElemType == Builder::CooperativeMatrixElementType::Int32) { + matrixC = + waveSize == 64 ? builder.CreateShuffleVector(matrixC, ArrayRef({0, 1, 2, 3}), "shuffleVector") : matrixC; + } else if (accumElemType == Builder::CooperativeMatrixElementType::Float16 || + accumElemType == Builder::CooperativeMatrixElementType::Int16) { + { + matrixC = waveSize == 64 ? builder.CreateShuffleVector(matrixC, ArrayRef({0, 1, 2, 3}), "shuffleVector") + : matrixC; + } + unsigned matrixLength = cast(matrixC->getType())->getNumElements(); + Type *accumType = FixedVectorType::get(builder.getHalfTy(), matrixLength * 2); + matrixC = builder.CreateBitCast(matrixC, accumType); + } else { + llvm_unreachable("Accumulator element type is not supported!"); + } + + if (factorElemType == Builder::CooperativeMatrixElementType::Float16 && + accumElemType == Builder::CooperativeMatrixElementType::Float32) { + matrixD = builder.CreateIntrinsic(matrixC->getType(), Intrinsic::amdgcn_wmma_f32_16x16x16_f16, + {matrixA, matrixB, matrixC}, nullptr, instName); + + } else if (factorElemType == Builder::CooperativeMatrixElementType::Int8 && + accumElemType == Builder::CooperativeMatrixElementType::Int32) { + matrixD = builder.CreateIntrinsic( + matrixC->getType(), Intrinsic::amdgcn_wmma_i32_16x16x16_iu8, + {builder.getInt1(isSignedA), matrixA, builder.getInt1(isSignedB), matrixB, matrixC, builder.getInt1(isSat)}, + nullptr, instName); + + } else if (factorElemType == Builder::CooperativeMatrixElementType::Float16 && + accumElemType == Builder::CooperativeMatrixElementType::Float16) { + // Matrix convert to match intrinsic arguments: Wave32: float32*v8->half*v16 + // Wave64: float32*v4->half*v8 + matrixD = builder.CreateIntrinsic(matrixC->getType(), Intrinsic::amdgcn_wmma_f16_16x16x16_f16, + {matrixA, matrixB, matrixC, builder.getInt1(isSat)}, nullptr, instName); + } else { + llvm_unreachable("The accumulator type is not supported."); + } + + if (accumElemType == Builder::CooperativeMatrixElementType::Float16 || + accumElemType == Builder::CooperativeMatrixElementType::Int16) { + unsigned coopVeclength = cast(matrixD->getType())->getNumElements(); + Type *wordTy = builder.transCooperativeMatrixElementType(accumElemType)->isIntOrIntVectorTy() + ? builder.getInt32Ty() + : builder.getFloatTy(); + matrixD = builder.CreateBitCast(matrixD, FixedVectorType::get(wordTy, coopVeclength / 2)); + { + matrixD = waveSize == 64 ? builder.CreateShuffleVector(matrixD, PoisonValue::get(matrixD->getType()), + ArrayRef{0, 1, 2, 3, 4, 5, 6, 7}) + : matrixD; + } + } else { + matrixD = waveSize == 64 ? builder.CreateShuffleVector(matrixD, PoisonValue::get(matrixD->getType()), + ArrayRef{0, 1, 2, 3, 4, 5, 6, 7}) + : matrixD; + } + return matrixD; + } else { // Emulator on NAVI2X + + Type *packedTy = (factorElemType == Builder::CooperativeMatrixElementType::Float16) ? builder.getFloatTy() + : builder.getInt32Ty(); + Value *dotProductValue; + + Value *threadId = getLaneNumber(builder); + Value *laneGroupIdx = builder.CreateUDiv(threadId, builder.getInt32(16)); + Value *isEvenGroup = + builder.CreateICmpEQ(builder.CreateAnd(laneGroupIdx, builder.getInt32(1)), builder.getInt32(0)); + + unsigned flags = (isSignedB << 1) | isSignedA; + auto mapFuncReadLane = [](BuilderBase &builder, ArrayRef mappedArgs, + ArrayRef passthroughArgs) -> Value * { + Type *const int32Ty = builder.getInt32Ty(); + + return builder.CreateIntrinsic(int32Ty, Intrinsic::amdgcn_readlane, {mappedArgs[0], passthroughArgs[0]}); + }; + + // matrixC is not reshaped for gfx10 + if (accumElemType == Builder::CooperativeMatrixElementType::Float32 || + accumElemType == Builder::CooperativeMatrixElementType::Int32) { + dotProductValue = PoisonValue::get(FixedVectorType::get(packedTy, 8)); + for (unsigned idxc = 0; idxc < 8; ++idxc) { + Value *rowlowgroup = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc * 2)); + Value *rowhighgroup = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc * 2 + 1)); + Value *rowData = builder.CreateSelect(isEvenGroup, rowlowgroup, rowhighgroup); + Value *mulAB; + Value *initAccumulator = builder.CreateExtractElement(matrixC, idxc); + if (factorElemType == Builder::CooperativeMatrixElementType::Float16) { + mulAB = createDotProductFp16Fp32(rowData, matrixB, initAccumulator, isSat, instName, insertPos); + } else if (factorElemType == Builder::CooperativeMatrixElementType::Int16) { + mulAB = createDotProductInt16Int32(rowData, matrixB, initAccumulator, flags, isSat, instName, insertPos); + } else if (factorElemType == Builder::CooperativeMatrixElementType::Int8) { + mulAB = createDotProductInt8Int32(rowData, matrixB, initAccumulator, flags, isSat, instName, insertPos); + } else { + llvm_unreachable("Unsupported element type!"); + } + dotProductValue = builder.CreateInsertElement(dotProductValue, mulAB, idxc); + } + } else if (accumElemType == Builder::CooperativeMatrixElementType::Int16 || + accumElemType == Builder::CooperativeMatrixElementType::Float16) { + dotProductValue = + PoisonValue::get(FixedVectorType::get(builder.transCooperativeMatrixElementType(accumElemType), 8)); + // For gfx10, A*B:8*float32->16*half C: no reshape for 16bit, still 16*half + Value *colData = convCoopMatrixVecToFlatVec(builder, matrixB, factorElemType, + Builder::CooperativeMatrixLayout::FactorMatrixLayout); + matrixC = convCoopMatrixVecToFlatVec(builder, matrixC, accumElemType, + Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout); + + for (unsigned idxc = 0, accIdx = 0; idxc < 16; idxc += 4, accIdx += 2) { + Value *rowData1Low = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc)); + Value *rowData2Low = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc + 1)); + Value *rowData1High = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc + 2)); + Value *rowData2High = builder.CreateMapToSimpleType(mapFuncReadLane, matrixA, builder.getInt32(idxc + 3)); + + Value *rowData1 = builder.CreateSelect(isEvenGroup, rowData1Low, rowData1High); + Value *rowData2 = builder.CreateSelect(isEvenGroup, rowData2Low, rowData2High); + + rowData1 = convCoopMatrixVecToFlatVec(builder, rowData1, factorElemType, + Builder::CooperativeMatrixLayout::FactorMatrixLayout); + rowData2 = convCoopMatrixVecToFlatVec(builder, rowData2, factorElemType, + Builder::CooperativeMatrixLayout::FactorMatrixLayout); + + Value *mulAB1; + Value *mulAB2; + Value *accumulator1 = builder.CreateExtractElement(matrixC, accIdx); + Value *accumulator2 = builder.CreateExtractElement(matrixC, accIdx + 1); + + if (accumElemType == Builder::CooperativeMatrixElementType::Float16) { + mulAB1 = createDotProductFp16Fp16(rowData1, colData, accumulator1, isSat, instName, insertPos); + mulAB2 = createDotProductFp16Fp16(rowData2, colData, accumulator2, isSat, instName, insertPos); + } else { + mulAB1 = createDotProductInt16Int16(rowData1, colData, accumulator1, flags, isSat, instName, insertPos); + mulAB2 = createDotProductInt16Int16(rowData2, colData, accumulator2, flags, isSat, instName, insertPos); + } + dotProductValue = builder.CreateInsertElement(dotProductValue, mulAB1, accIdx); + dotProductValue = builder.CreateInsertElement(dotProductValue, mulAB2, accIdx + 1); + } + + dotProductValue = convFlatVecToCoopMatrixVec(builder, dotProductValue, accumElemType, + Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout); + } else { + llvm_unreachable("The accumulator type is not supported."); + } + return dotProductValue; + } +} + +// ===================================================================================================================== +// Create scalar from dot product of scalar or vector FP type. (The dot product of two scalars is their product.) +// The two vectors must be the same floating point scalar/vector type. +// Returns a value whose type is the element type of the vectors. +// +// @param vector1 : The float vector 1 +// @param vector2 : The float vector 2 +// @param initAccumulator : Initial accumulator +// @param isSat: SaturatingAccumulation for calculation +// @param instName : Name to give instruction(s) +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::createDotProductFp16Fp32(Value *const vector1, Value *const vector2, + Value *const initAccumulator, bool isSat, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + const unsigned compCount = cast(vector1->getType())->getNumElements(); + Value *scalar = initAccumulator; + auto intrinsicDot = Intrinsic::amdgcn_fdot2; + for (unsigned i = 0; i < compCount; ++i) { + Value *input1 = builder.CreateExtractElement(vector1, i); + input1 = builder.CreateBitCast(input1, FixedVectorType::get(builder.getHalfTy(), 2)); + Value *input2 = builder.CreateExtractElement(vector2, i); + input2 = builder.CreateBitCast(input2, FixedVectorType::get(builder.getHalfTy(), 2)); + scalar = + builder.CreateIntrinsic(intrinsicDot, {}, {input1, input2, scalar, builder.getInt1(isSat)}, nullptr, instName); + } + scalar->setName(instName); + return scalar; +} + +// ===================================================================================================================== +// Create scalar from dot product of scalar or vector FP type. (The dot product of two scalars is their product.) +// +// @param vector1 : The float vector 1 +// @param vector2 : The float vector 2 +// @param initAccumulator : Initial accumulator +// @param isSat: SaturatingAccumulation for calculation +// @param instName : Name to give instruction(s) +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::createDotProductFp16Fp16(Value *const vector1, Value *const vector2, + Value *const initAccumulator, bool isSat, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + Value *product = builder.CreateFMul(vector1, vector2); + if (!isa(product->getType())) + return product; + + const unsigned compCount = cast(product->getType())->getNumElements(); + Value *scalar = initAccumulator; + + for (unsigned i = 0; i < compCount; ++i) + scalar = builder.CreateFAdd(scalar, builder.CreateExtractElement(product, i)); + + scalar->setName(instName); + return scalar; +} + +// ===================================================================================================================== +// Create code to calculate the dot product of two integer vectors, with optional accumulator, using hardware support +// where available. +// Use a value of 0 for no accumulation and the value type is consistent with the result type. The result is saturated +// if there is an accumulator. The component type of input vectors can have 8-bit/16-bit/32-bit and i32/i16/i8 result. +// +// @param vector1 : The integer vector 1 +// @param vector2 : The integer vector 2 +// @param accumulator : The accumulator to the scalar of dot product +// @param flags : Bit 0 is "first vector is signed" and bit 1 is "second vector is signed" +// @param isSat: SaturatingAccumulation for calculation +// @param instName : Name to give instruction(s) +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::createDotProductInt8Int32(Value *vector1, Value *vector2, Value *accumulator, + unsigned flags, bool isSat, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + const bool isSigned = (flags & lgc::Builder::FirstVectorSigned); + auto intrinsicDot = isSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4; + + Value *scalar = builder.getInt32(0); + const unsigned compCount = cast(vector1->getType())->getNumElements(); + for (unsigned i = 0; i < compCount; ++i) { + Value *input1 = builder.CreateExtractElement(vector1, i); + Value *input2 = builder.CreateExtractElement(vector2, i); + scalar = + builder.CreateIntrinsic(intrinsicDot, {}, {input1, input2, scalar, builder.getInt1(false)}, nullptr, instName); + } + + // Always use sadd_sat here as uint32@C is not supported. + scalar = builder.CreateSExt(scalar, builder.getInt32Ty()); + if (isSat) { + scalar = builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, scalar, accumulator, nullptr, instName); + } else { + scalar = builder.CreateAdd(scalar, accumulator, instName); + } + scalar->setName(instName); + return scalar; +} + +// ===================================================================================================================== +// Create code to calculate the dot product of two integer vectors, with optional accumulator +// +// @param vector1 : The integer vector 1 +// @param vector2 : The integer vector 2 +// @param accumulator : The accumulator to the scalar of dot product +// @param flags : Bit 0 is "first vector is signed" and bit 1 is "second vector is signed" +// @param isSat: SaturatingAccumulation for calculation +// @param instName : Name to give instruction(s) +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::createDotProductInt16Int32(Value *vector1, Value *vector2, Value *accumulator, + unsigned flags, bool isSat, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + + const bool isSigned = (flags & lgc::Builder::FirstVectorSigned); + auto intrinsicDot = isSigned ? Intrinsic::amdgcn_sdot2 : Intrinsic::amdgcn_udot2; + + Value *scalar = accumulator; + const unsigned compCount = cast(vector1->getType())->getNumElements(); + for (unsigned i = 0; i < compCount; ++i) { + Value *input1 = builder.CreateExtractElement(vector1, i); + input1 = builder.CreateBitCast(input1, FixedVectorType::get(builder.getInt16Ty(), 2)); + Value *input2 = builder.CreateExtractElement(vector2, i); + input2 = builder.CreateBitCast(input2, FixedVectorType::get(builder.getInt16Ty(), 2)); + scalar = + builder.CreateIntrinsic(intrinsicDot, {}, {input1, input2, scalar, builder.getInt1(isSat)}, nullptr, instName); + } + scalar->setName(instName); + return scalar; +} + +// ===================================================================================================================== +// Create code to calculate the dot product of two integer vectors, with optional accumulator +// +// @param vector1 : The integer vector 1 +// @param vector2 : The integer vector 2 +// @param accumulator : The accumulator to the scalar of dot product +// @param flags : Bit 0 is "first vector is signed" and bit 1 is "second vector is signed" +// @param isSat: SaturatingAccumulation for calculation +// @param instName : Name to give instruction(s) +// @param insertPos : Where to insert the instruction +Value *LowerCooperativeMatrix::createDotProductInt16Int16(Value *vector1, Value *vector2, Value *accumulator, + unsigned flags, bool isSat, const Twine &instName, + Instruction *insertPos) { + BuilderBase builder(*m_context); + builder.SetInsertPoint(insertPos); + Type *inputTy = vector1->getType(); + assert(inputTy->isVectorTy() && inputTy->getScalarType()->isIntegerTy()); + + const unsigned compCount = cast(inputTy)->getNumElements(); + Type *outputTy = accumulator->getType(); + // The component of Vector 2 can be signed or unsigned + const bool isSigned = (flags & lgc::Builder::FirstVectorSigned); + // The mixed signed/unsigned is that component of Vector 1 is treated as signed and component of Vector 2 is treated + // as unsigned. + const bool isMixed = (flags == lgc::Builder::FirstVectorSigned); + + Type *targetTy = builder.getInt64Ty(); + // Emulate dot product with no HW support cases + Value *scalar = builder.getInt64(0); + for (unsigned elemIdx = 0; elemIdx < compCount; ++elemIdx) { + Value *elem1 = builder.CreateExtractElement(vector1, elemIdx); + elem1 = isSigned ? builder.CreateSExt(elem1, targetTy) : builder.CreateZExt(elem1, targetTy); + Value *elem2 = builder.CreateExtractElement(vector2, elemIdx); + elem2 = (isSigned && !isMixed) ? builder.CreateSExt(elem2, targetTy) : builder.CreateZExt(elem2, targetTy); + Value *product = builder.CreateMul(elem1, elem2); + scalar = builder.CreateAdd(product, scalar); + } + + scalar = builder.CreateTrunc(scalar, builder.getInt32Ty()); + accumulator = builder.CreateTrunc(accumulator, builder.getInt32Ty()); + Intrinsic::ID addIntrinsic = isSigned ? Intrinsic::sadd_sat : Intrinsic::uadd_sat; + scalar = builder.CreateBinaryIntrinsic(addIntrinsic, scalar, accumulator, nullptr, instName); + + const unsigned bitWidth = outputTy->getScalarSizeInBits(); + auto unsignedMax = (2ULL << (bitWidth - 1)) - 1; + auto signedMax = unsignedMax >> 1; + auto signedMin = -1ULL - signedMax; + + Value *minimum = nullptr, *maximum = nullptr; + Value *isUnderflow = nullptr, *isOverflow = nullptr; + if (isSigned) { + scalar = builder.CreateSExt(scalar, builder.getInt64Ty()); + minimum = ConstantInt::getSigned(builder.getInt64Ty(), signedMin); + maximum = ConstantInt::getSigned(builder.getInt64Ty(), signedMax); + isUnderflow = builder.CreateICmpSLT(scalar, minimum); + isOverflow = builder.CreateICmpSGT(scalar, maximum); + } else { + scalar = builder.CreateZExt(scalar, builder.getInt64Ty()); + minimum = builder.getInt64(0); + maximum = builder.getInt64(unsignedMax); + isUnderflow = builder.CreateICmpULT(scalar, minimum); + isOverflow = builder.CreateICmpUGT(scalar, maximum); + } + scalar = builder.CreateSelect(isUnderflow, minimum, scalar); + scalar = builder.CreateSelect(isOverflow, maximum, scalar); + scalar = builder.CreateTrunc(scalar, outputTy); + + scalar->setName(instName); + return scalar; +} + +// ===================================================================================================================== +// Get lane id. +// @param builder : The IR builder to create and insert IR instruction +Value *LowerCooperativeMatrix::getLaneNumber(BuilderBase &builder) { + Value *result = builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {builder.getInt32(-1), builder.getInt32(0)}); + if (m_pipelineState->getShaderWaveSize(m_shaderStage) == 64) + result = builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {}, {builder.getInt32(-1), result}); + return result; +} + +} // namespace lgc diff --git a/lgc/patch/PassRegistry.inc b/lgc/patch/PassRegistry.inc index 184dffe248..6b1f2da15a 100644 --- a/lgc/patch/PassRegistry.inc +++ b/lgc/patch/PassRegistry.inc @@ -79,6 +79,9 @@ LLPC_MODULE_PASS("lgc-vertex-fetch", LowerVertexFetch) LLPC_MODULE_PASS("lgc-frag-color-export", LowerFragColorExport) LLPC_MODULE_PASS("lgc-lower-debug-printf", LowerDebugPrintf) +LLPC_FUNCTION_PASS("lgc-combine-cooperative-matrix", CombineCooperativeMatrix) +LLPC_MODULE_PASS("lgc-lower-cooperative-matrix", LowerCooperativeMatrix) + #undef LLPC_PASS #undef LLPC_LOOP_PASS #undef LLPC_FUNCTION_PASS diff --git a/lgc/patch/Patch.cpp b/lgc/patch/Patch.cpp index dd2831a8fd..fc55926a1f 100644 --- a/lgc/patch/Patch.cpp +++ b/lgc/patch/Patch.cpp @@ -35,8 +35,10 @@ #include "lgc/PassManager.h" #include "lgc/Pipeline.h" #include "lgc/builder/BuilderReplayer.h" +#include "lgc/patch/CombineCooperativeMatrix.h" #include "lgc/patch/Continufy.h" #include "lgc/patch/FragColorExport.h" +#include "lgc/patch/LowerCooperativeMatrix.h" #include "lgc/patch/LowerDebugPrintf.h" #include "lgc/patch/PatchBufferOp.h" #include "lgc/patch/PatchCheckShaderCache.h" @@ -153,6 +155,10 @@ void Patch::addPasses(PipelineState *pipelineState, lgc::PassManager &passMgr, T passMgr.addPass(IPSCCPPass()); passMgr.addPass(LowerDebugPrintf()); + passMgr.addPass(createModuleToFunctionPassAdaptor(CombineCooperativeMatrix())); + // Lower the cooperative matrix + passMgr.addPass(LowerCooperativeMatrix()); + if (pipelineState->hasShaderStage(ShaderStageVertex) && !pipelineState->hasShaderStage(ShaderStageTessControl) && pipelineState->hasShaderStage(ShaderStageTessEval)) passMgr.addPass(TcsPassthroughShader()); diff --git a/lgc/patch/PatchBufferOp.cpp b/lgc/patch/PatchBufferOp.cpp index ab5555954a..58baf82805 100644 --- a/lgc/patch/PatchBufferOp.cpp +++ b/lgc/patch/PatchBufferOp.cpp @@ -35,7 +35,6 @@ #include "lgc/state/IntrinsDefs.h" #include "lgc/state/PipelineState.h" #include "lgc/state/TargetInfo.h" -#include "lgc/util/TypeLowering.h" #include "llvm-dialects/Dialect/Visitor.h" #include "llvm/ADT/PostOrderIterator.h" #if LLVM_MAIN_REVISION && LLVM_MAIN_REVISION < 458033 diff --git a/lgc/patch/PatchCopyShader.cpp b/lgc/patch/PatchCopyShader.cpp index 94e6e66376..4829de1a27 100644 --- a/lgc/patch/PatchCopyShader.cpp +++ b/lgc/patch/PatchCopyShader.cpp @@ -121,7 +121,9 @@ bool PatchCopyShader::runImpl(Module &module, PipelineShadersResult &pipelineSha // i32 inreg streamOutOffset3, // i32 vertexOffset) // + argTys = {int32Ty, int32Ty, int32Ty, int32Ty, int32Ty, int32Ty, int32Ty, int32Ty, int32Ty, int32Ty}; + argInReg = {true, true, true, true, true, true, true, true, true, false}; // clang-format off argNames = {"globalTable", diff --git a/lgc/patch/PatchEntryPointMutate.cpp b/lgc/patch/PatchEntryPointMutate.cpp index 96b5899e72..bdc8421e46 100644 --- a/lgc/patch/PatchEntryPointMutate.cpp +++ b/lgc/patch/PatchEntryPointMutate.cpp @@ -68,7 +68,6 @@ #include "lgc/state/TargetInfo.h" #include "lgc/util/AddressExtender.h" #include "lgc/util/BuilderBase.h" -#include "lgc/util/CpsStackLowering.h" #include "llvm-dialects/Dialect/Visitor.h" #include "llvm/Analysis/AliasAnalysis.h" // for MemoryEffects #include "llvm/IR/IntrinsicsAMDGPU.h" @@ -143,6 +142,8 @@ bool PatchEntryPointMutate::runImpl(Module &module, PipelineShadersResult &pipel m_pipelineState = pipelineState; + stackLowering = std::make_unique(module.getContext(), ADDR_SPACE_PRIVATE); + const unsigned stageMask = m_pipelineState->getShaderStageMask(); m_hasTs = (stageMask & (shaderStageToMask(ShaderStageTessControl) | shaderStageToMask(ShaderStageTessEval))) != 0; m_hasGs = (stageMask & shaderStageToMask(ShaderStageGeometry)) != 0; @@ -515,7 +516,8 @@ bool PatchEntryPointMutate::lowerCpsOps(Function *func, ShaderInputs *shaderInpu if (!isCpsFunc) { IRBuilder<> builder(func->getContext()); builder.SetInsertPointPastAllocas(func); - Value *vspStorage = builder.CreateAlloca(builder.getPtrTy(getLoweredCpsStackAddrSpace()), ADDR_SPACE_PRIVATE); + Value *vspStorage = + builder.CreateAlloca(builder.getPtrTy(stackLowering->getLoweredCpsStackAddrSpace()), ADDR_SPACE_PRIVATE); m_funcCpsStackMap[func] = vspStorage; } @@ -554,7 +556,7 @@ bool PatchEntryPointMutate::lowerCpsOps(Function *func, ShaderInputs *shaderInpu // Lower returns. for (auto *ret : retInstrs) { - auto *vspTy = builder.getPtrTy(getLoweredCpsStackAddrSpace()); + auto *vspTy = builder.getPtrTy(stackLowering->getLoweredCpsStackAddrSpace()); exitInfos.push_back(CpsExitInfo(ret->getParent(), {builder.getInt32(0), PoisonValue::get(vspTy)})); builder.SetInsertPoint(ret); builder.CreateBr(tailBlock); @@ -686,10 +688,9 @@ bool PatchEntryPointMutate::lowerCpsOps(Function *func, ShaderInputs *shaderInpu auto funcName = func->getName(); // Lower cps stack operations - CpsStackLowering stackLowering(func->getContext()); - stackLowering.lowerCpsStackOps(*func, m_funcCpsStackMap[func]); + stackLowering->lowerCpsStackOps(*func, m_funcCpsStackMap[func]); - stackSize += stackLowering.getStackSize(); + stackSize += stackLowering->getStackSizeInBytes(); // Set per-function .frontend_stack_size PAL metadata. auto &shaderFunctions = m_pipelineState->getPalMetadata() ->getPipelineNode() @@ -719,7 +720,7 @@ Function *PatchEntryPointMutate::lowerCpsFunction(Function *func, ArrayRef builder(func->getContext()); SmallVector newArgTys; newArgTys.append(fixedShaderArgTys.begin(), fixedShaderArgTys.end()); - newArgTys.append({builder.getInt32Ty(), builder.getPtrTy(getLoweredCpsStackAddrSpace())}); + newArgTys.append({builder.getInt32Ty(), builder.getPtrTy(stackLowering->getLoweredCpsStackAddrSpace())}); auto remainingArgs = func->getFunctionType()->params().drop_front(1); newArgTys.append(remainingArgs.begin(), remainingArgs.end()); FunctionType *newFuncTy = FunctionType::get(builder.getVoidTy(), newArgTys, false); @@ -758,7 +759,8 @@ Function *PatchEntryPointMutate::lowerCpsFunction(Function *func, ArrayRefsplice(newFunc->begin(), func); builder.SetInsertPointPastAllocas(newFunc); - Value *vspStorage = builder.CreateAlloca(builder.getPtrTy(getLoweredCpsStackAddrSpace()), ADDR_SPACE_PRIVATE); + Value *vspStorage = + builder.CreateAlloca(builder.getPtrTy(stackLowering->getLoweredCpsStackAddrSpace()), ADDR_SPACE_PRIVATE); m_funcCpsStackMap[newFunc] = vspStorage; // Function arguments: {fixed_shader_arguments, vcr, vsp, original_func_arguments_exclude_state} @@ -766,8 +768,8 @@ Function *PatchEntryPointMutate::lowerCpsFunction(Function *func, ArrayRefgetType()->isEmptyTy()) { // Get stack address of pushed state and load it from continuation stack. unsigned stateSize = layout.getTypeStoreSize(state->getType()); - vsp = builder.CreateConstInBoundsGEP1_32(builder.getInt8Ty(), vsp, -alignTo(stateSize, continuationStackAlignment)); - Value *newState = builder.CreateAlignedLoad(state->getType(), vsp, Align(continuationStackAlignment), "cps.state"); + vsp = builder.CreateConstInBoundsGEP1_32(builder.getInt8Ty(), vsp, -alignTo(stateSize, ContinuationStackAlignment)); + Value *newState = builder.CreateAlignedLoad(state->getType(), vsp, Align(ContinuationStackAlignment), "cps.state"); state->replaceAllUsesWith(newState); } builder.CreateStore(vsp, vspStorage); @@ -816,14 +818,15 @@ unsigned PatchEntryPointMutate::lowerCpsJump(Function *parent, cps::JumpOp *jump // Pushing state onto stack and get new vsp. Value *state = jumpOp->getState(); - Value *vsp = builder.CreateAlignedLoad(builder.getPtrTy(getLoweredCpsStackAddrSpace()), m_funcCpsStackMap[parent], - Align(getLoweredCpsStackPointerSize(layout))); + Value *vsp = + builder.CreateAlignedLoad(builder.getPtrTy(stackLowering->getLoweredCpsStackAddrSpace()), + m_funcCpsStackMap[parent], Align(stackLowering->getLoweredCpsStackPointerSize(layout))); unsigned stateSize = 0; if (!state->getType()->isEmptyTy()) { stateSize = layout.getTypeStoreSize(state->getType()); builder.CreateStore(state, vsp); // Make vsp properly aligned across cps function. - stateSize = alignTo(stateSize, continuationStackAlignment); + stateSize = alignTo(stateSize, ContinuationStackAlignment); vsp = builder.CreateConstGEP1_32(builder.getInt8Ty(), vsp, stateSize); } diff --git a/lgc/patch/PatchResourceCollect.cpp b/lgc/patch/PatchResourceCollect.cpp index 1b3bc9a9ec..d49c2ed82d 100644 --- a/lgc/patch/PatchResourceCollect.cpp +++ b/lgc/patch/PatchResourceCollect.cpp @@ -3309,11 +3309,12 @@ void PatchResourceCollect::scalarizeForInOutPacking(Module *module) { payload.inputCalls.push_back(&input); } }; - static auto visitor = llvm_dialects::VisitorBuilder() - .setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration) - .add([](auto &payload, auto &op) { visitInput(payload, op); }) - .add([](auto &payload, auto &op) { visitInput(payload, op); }) - .build(); + static auto visitor = + llvm_dialects::VisitorBuilder() + .setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration) + .addSet( + [](Payload &payload, Instruction &op) { visitInput(payload, cast(op)); }) + .build(); visitor.visit(payload, *module); for (Function &func : *module) { diff --git a/lgc/patch/RegisterMetadataBuilder.cpp b/lgc/patch/RegisterMetadataBuilder.cpp index 08f6fc3389..8a1990f4bf 100644 --- a/lgc/patch/RegisterMetadataBuilder.cpp +++ b/lgc/patch/RegisterMetadataBuilder.cpp @@ -1068,8 +1068,6 @@ void RegisterMetadataBuilder::buildShaderExecutionRegisters(Util::Abi::HardwareS unsigned sgprLimits = 0; unsigned vgprLimits = 0; if (apiStage1 == ShaderStageCopyShader) { - // NOTE: For copy shader, usually we use fixed number of user data registers. - // But in some cases, we may change user data registers, we use variable to keep user sgpr count here userDataCount = lgc::CopyShaderUserSgprCount; sgprLimits = m_pipelineState->getTargetInfo().getGpuProperty().maxSgprsAvailable; vgprLimits = m_pipelineState->getTargetInfo().getGpuProperty().maxVgprsAvailable; diff --git a/lgc/patch/VertexFetch.cpp b/lgc/patch/VertexFetch.cpp index c8242c7bba..92acdf8019 100644 --- a/lgc/patch/VertexFetch.cpp +++ b/lgc/patch/VertexFetch.cpp @@ -61,7 +61,7 @@ namespace { // Map vkgc static constexpr unsigned InternalDescriptorSetId = static_cast(-1); static constexpr unsigned FetchShaderInternalBufferBinding = 5; // Descriptor binding for uber fetch shader -static constexpr unsigned CurrentAttributeBufferBinding = 24; // Descriptor binding for current attribute +static constexpr unsigned CurrentAttributeBufferBinding = 1; // Descriptor binding for current attribute static constexpr unsigned GenericVertexFetchShaderBinding = 0; // Descriptor binding for generic vertex fetch shader static constexpr unsigned VertexInputBindingCurrent = 64; // Vertex input binding for current attribute @@ -175,6 +175,23 @@ const VertexCompFormatInfo VertexFetchImpl::m_vertexCompFormatInfo[] = { {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormatReserved {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat8_8_8_8_Bgra {3, 1, 3, BUF_DATA_FORMAT_8}, // BufDataFormat8_8_8 + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat8_8_8_Bgr, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat2_10_10_10_Bgra, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat64, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat64_64, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat64_64_64, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat64_64_64_64, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat4_4, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat4_4_4_4, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat4_4_4_4_Bgra, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat5_6_5, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat5_6_5_Bgr, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat5_6_5_1, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat5_6_5_1_Bgra, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat1_5_6_5, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat5_9_9_9, + {0, 0, 0, BUF_DATA_FORMAT_INVALID}, // BufDataFormat8_A, + {6, 2, 3, BUF_DATA_FORMAT_16}, // BufDataFormat16_16_16 }; // clang-format off @@ -1171,9 +1188,9 @@ Value *VertexFetchImpl::fetchVertex(Type *inputTy, const VertexInputDescription } vbIndex = m_vertexIndex; } else { - if (description->inputRate == VertexInputRateNone) { + if (description->divisor == 0) { vbIndex = ShaderInputs::getSpecialUserData(UserDataMapping::BaseInstance, builder); - } else if (description->inputRate == VertexInputRateInstance) { + } else if (description->divisor == 1) { // Use instance index if (!m_instanceIndex) { IRBuilder<>::InsertPointGuard guard(builder); @@ -1184,7 +1201,7 @@ Value *VertexFetchImpl::fetchVertex(Type *inputTy, const VertexInputDescription } else { // There is a divisor. vbIndex = builder.CreateUDiv(ShaderInputs::getInput(ShaderInput::InstanceId, builder, *m_lgcContext), - builder.getInt32(description->inputRate)); + builder.getInt32(description->divisor)); vbIndex = builder.CreateAdd(vbIndex, ShaderInputs::getSpecialUserData(UserDataMapping::BaseInstance, builder)); } } @@ -1475,6 +1492,10 @@ VertexFormatInfo VertexFetchImpl::getVertexFormatInfo(const VertexInputDescripti info.dfmt = BufDataFormat8_8_8; info.numChannels = 3; break; + case BufDataFormat16_16_16: + info.dfmt = BufDataFormat16_16_16; + info.numChannels = 3; + break; default: break; } @@ -1529,13 +1550,16 @@ unsigned VertexFetchImpl::mapVertexFormat(unsigned dfmt, unsigned nfmt) const { // @param binding : ID of vertex buffer binding // @param builder : Builder with insert point set Value *VertexFetchImpl::loadVertexBufferDescriptor(unsigned binding, BuilderImpl &builderImpl) { + BuilderBase &builder = BuilderBase::get(builderImpl); if (builderImpl.useVertexBufferDescArray()) { Value *vtxDesc = nullptr; // Create descriptor for current attribute if (binding == VertexInputBindingCurrent) { if (m_curAttribBufferDescr == nullptr) { + IRBuilder<>::InsertPointGuard guard(builder); + builder.SetInsertPointPastAllocas(builder.GetInsertBlock()->getParent()); auto descPtr = builderImpl.CreateBufferDesc(InternalDescriptorSetId, CurrentAttributeBufferBinding, - builderImpl.getInt32(0), Builder::BufferFlagAddress); + builderImpl.getInt32(0), lgc::Builder::BufferFlagAddress); // Create descriptor by a 64-bits pointer m_curAttribBufferDescr = builderImpl.buildInlineBufferDesc(descPtr); } @@ -1543,13 +1567,12 @@ Value *VertexFetchImpl::loadVertexBufferDescriptor(unsigned binding, BuilderImpl } else { // Create descriptor for vertex buffer vtxDesc = builderImpl.CreateBufferDesc(InternalDescriptorSetId, GenericVertexFetchShaderBinding, - builderImpl.getInt32(binding), Builder::BufferFlagNonConst); + builderImpl.getInt32(binding), lgc::Builder::BufferFlagNonConst); } return vtxDesc; } - BuilderBase &builder = BuilderBase::get(builderImpl); // Get the vertex buffer table pointer as pointer to v4i32 descriptor. Type *vbDescTy = FixedVectorType::get(Type::getInt32Ty(*m_context), 4); if (!m_vertexBufTablePtr) { @@ -1595,7 +1618,7 @@ void VertexFetchImpl::addVertexFetchInst(Value *vbDesc, unsigned numChannels, bo // NOTE: For the vertex data format 8_8, 8_8_8_8, 16_16, and 16_16_16_16, tbuffer_load has a HW defect when // vertex buffer is unaligned. Therefore, we have to split the vertex fetch to component-based ones dfmt != BufDataFormat8_8 && dfmt != BufDataFormat8_8_8_8 && dfmt != BufDataFormat16_16 && - dfmt != BufDataFormat16_16_16_16 && dfmt != BufDataFormat8_8_8) || + dfmt != BufDataFormat16_16_16_16 && dfmt != BufDataFormat8_8_8 && dfmt != BufDataFormat16_16_16) || formatInfo->compDfmt == dfmt) { // Do vertex fetch Value *args[] = { diff --git a/lgc/state/PipelineState.cpp b/lgc/state/PipelineState.cpp index 6833b027fa..25fb5391e5 100644 --- a/lgc/state/PipelineState.cpp +++ b/lgc/state/PipelineState.cpp @@ -111,6 +111,7 @@ static unsigned getMaxComponentBitCount(BufDataFormat dfmt) { return 11; case BufDataFormat16: case BufDataFormat16_16: + case BufDataFormat16_16_16: case BufDataFormat16_16_16_16: return 16; case BufDataFormat32: @@ -178,6 +179,7 @@ static unsigned getNumChannels(BufDataFormat dfmt) { case BufDataFormat8_8_8_Bgr: case BufDataFormat10_11_11: case BufDataFormat11_11_10: + case BufDataFormat16_16_16: case BufDataFormat32_32_32: case BufDataFormat64_64_64: case BufDataFormat5_6_5: diff --git a/lgc/test/Transforms/CombineCooperativeMatrix/constants.lgc b/lgc/test/Transforms/CombineCooperativeMatrix/constants.lgc new file mode 100644 index 0000000000..6fcd242e9e --- /dev/null +++ b/lgc/test/Transforms/CombineCooperativeMatrix/constants.lgc @@ -0,0 +1,154 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-combine-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define <8 x float> @transpose_undef() { +; CHECK-LABEL: @transpose_undef( +; CHECK-NEXT: ret <8 x float> undef +; + %r = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> undef, i32 1, i32 0) + ret <8 x float> %r +} + +define <8 x float> @transpose_poison() { +; CHECK-LABEL: @transpose_poison( +; CHECK-NEXT: ret <8 x float> poison +; + %r = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> poison, i32 1, i32 0) + ret <8 x float> %r +} + +define <8 x float> @transpose_zero() { +; CHECK-LABEL: @transpose_zero( +; CHECK-NEXT: ret <8 x float> zeroinitializer +; + %r = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> zeroinitializer, i32 1, i32 0) + ret <8 x float> %r +} + +define <8 x float> @relayout_undef() { +; CHECK-LABEL: @relayout_undef( +; CHECK-NEXT: ret <8 x float> undef +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> undef, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %r +} + +define <8 x float> @relayout_poison() { +; CHECK-LABEL: @relayout_poison( +; CHECK-NEXT: ret <8 x float> poison +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> poison, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %r +} + +define <8 x float> @relayout_zero() { +; CHECK-LABEL: @relayout_zero( +; CHECK-NEXT: ret <8 x float> zeroinitializer +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> zeroinitializer, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %r +} + +define <8 x float> @fptrunc_undef() { +; CHECK-LABEL: @fptrunc_undef( +; CHECK-NEXT: [[R:%.*]] = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 45, <8 x float> undef, i32 2, i32 1, i32 0, i32 0) +; CHECK-NEXT: ret <8 x float> [[R]] +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 45, <8 x float> undef, i32 2, i32 1, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x float> @fptrunc_poison() { +; CHECK-LABEL: @fptrunc_poison( +; CHECK-NEXT: ret <8 x float> poison +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 45, <8 x float> poison, i32 2, i32 1, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x float> @fptrunc_zero() { +; CHECK-LABEL: @fptrunc_zero( +; CHECK-NEXT: ret <8 x float> zeroinitializer +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 45, <8 x float> zeroinitializer, i32 2, i32 1, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x float> @fpext_undef() { +; CHECK-LABEL: @fpext_undef( +; CHECK-NEXT: [[R:%.*]] = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 46, <8 x float> undef, i32 1, i32 2, i32 0, i32 0) +; CHECK-NEXT: ret <8 x float> [[R]] +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 46, <8 x float> undef, i32 1, i32 2, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x float> @fpext_poison() { +; CHECK-LABEL: @fpext_poison( +; CHECK-NEXT: ret <8 x float> poison +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 46, <8 x float> poison, i32 1, i32 2, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x float> @fpext_zero() { +; CHECK-LABEL: @fpext_zero( +; CHECK-NEXT: ret <8 x float> zeroinitializer +; + %r = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 46, <8 x float> zeroinitializer, i32 1, i32 2, i32 0, i32 0) + ret <8 x float> %r +} + +define <8 x i32> @trunc_undef() { +; CHECK-LABEL: @trunc_undef( +; CHECK-NEXT: [[R:%.*]] = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 38, <8 x i32> undef, i32 5, i32 4, i32 0, i32 0) +; CHECK-NEXT: ret <8 x i32> [[R]] +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 38, <8 x i32> undef, i32 5, i32 4, i32 0, i32 0) + ret <8 x i32> %r +} + +define <8 x i32> @trunc_poison() { +; CHECK-LABEL: @trunc_poison( +; CHECK-NEXT: ret <8 x i32> poison +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 38, <8 x i32> poison, i32 5, i32 4, i32 0, i32 0) + ret <8 x i32> %r +} + +define <8 x i32> @trunc_zero() { +; CHECK-LABEL: @trunc_zero( +; CHECK-NEXT: ret <8 x i32> zeroinitializer +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 38, <8 x i32> zeroinitializer, i32 5, i32 4, i32 0, i32 0) + ret <8 x i32> %r +} + +define <8 x i32> @zext_undef() { +; CHECK-LABEL: @zext_undef( +; CHECK-NEXT: [[R:%.*]] = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 39, <8 x i32> undef, i32 4, i32 5, i32 0, i32 0) +; CHECK-NEXT: ret <8 x i32> [[R]] +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 39, <8 x i32> undef, i32 4, i32 5, i32 0, i32 0) + ret <8 x i32> %r +} + +define <8 x i32> @zext_poison() { +; CHECK-LABEL: @zext_poison( +; CHECK-NEXT: ret <8 x i32> poison +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 39, <8 x i32> poison, i32 4, i32 5, i32 0, i32 0) + ret <8 x i32> %r +} + +define <8 x i32> @zext_zero() { +; CHECK-LABEL: @zext_zero( +; CHECK-NEXT: ret <8 x i32> zeroinitializer +; + %r = call <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32 39, <8 x i32> zeroinitializer, i32 4, i32 5, i32 0, i32 0) + ret <8 x i32> %r +} + +declare <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float>, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32, <8 x float>, i32, i32, i32, i32) +declare <8 x i32> @lgc.cooperative.matrix.convert.v8i32.i32.v8i32.i32.i32.i32.i32(i32, <8 x i32>, i32, i32, i32, i32) diff --git a/lgc/test/Transforms/CombineCooperativeMatrix/lit.local.cfg b/lgc/test/Transforms/CombineCooperativeMatrix/lit.local.cfg new file mode 100644 index 0000000000..a4266bc874 --- /dev/null +++ b/lgc/test/Transforms/CombineCooperativeMatrix/lit.local.cfg @@ -0,0 +1,2 @@ +if "vki_cooperative_matrix" not in config.available_features: + config.unsupported = True diff --git a/lgc/test/Transforms/CombineCooperativeMatrix/matmul-loop.lgc b/lgc/test/Transforms/CombineCooperativeMatrix/matmul-loop.lgc new file mode 100644 index 0000000000..95e479a498 --- /dev/null +++ b/lgc/test/Transforms/CombineCooperativeMatrix/matmul-loop.lgc @@ -0,0 +1,84 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-combine-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define void @matmul_f16(ptr %ptr) { +; CHECK-LABEL: @matmul_f16( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ACCUM_LOAD:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr [[PTR:%.*]], i32 4, i1 false, i32 1, i32 1, i32 0) +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[ACCUM_PHI:%.*]] = phi <8 x float> [ [[ACCUM_LOAD]], [[ENTRY:%.*]] ], [ [[MULADD:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[A:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[B:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[MULADD]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[A]], <8 x float> [[B]], <8 x float> [[ACCUM_PHI]], i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: [[CC:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[CC]], label [[LOOP]], label [[END:%.*]] +; CHECK: end: +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p0.i32.i1.i32.i32.i32.v8f32(ptr [[PTR]], i32 4, i1 true, i32 1, i32 1, i32 0, <8 x float> [[MULADD]]) +; CHECK-NEXT: ret void +; +entry: + %accum.load = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr %ptr, i32 4, i1 false, i32 1, i32 0, i32 0) + br label %loop + +loop: + %accum.phi = phi <8 x float> [ %accum.load, %entry ], [ %accum.next, %loop ] + + %a = call <8 x float> @getmat1() + %b = call <8 x float> @getmat1() + + %accum.cvt = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %accum.phi, i32 1, i32 1, i32 0, i32 1) + %muladd = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %a, <8 x float> %b, <8 x float> %accum.cvt, i1 true, i1 true, i32 1, i32 1) + %accum.next = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %muladd, i32 1, i32 1, i32 1, i32 0) + + %cc = call i1 @getcc() + br i1 %cc, label %loop, label %end + +end: + call void @lgc.cooperative.matrix.store.p0.i32.i1.i32.i32.i32.v8f32(ptr %ptr, i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> %accum.next) + ret void +} + +define void @matmul_f16_initzero(ptr %ptr) { +; CHECK-LABEL: @matmul_f16_initzero( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[ACCUM_PHI:%.*]] = phi <8 x float> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[MULADD:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[A:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[B:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[MULADD]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[A]], <8 x float> [[B]], <8 x float> [[ACCUM_PHI]], i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: [[CC:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[CC]], label [[LOOP]], label [[END:%.*]] +; CHECK: end: +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p0.i32.i1.i32.i32.i32.v8f32(ptr [[PTR:%.*]], i32 4, i1 true, i32 1, i32 1, i32 0, <8 x float> [[MULADD]]) +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %accum.phi = phi <8 x float> [ zeroinitializer, %entry ], [ %accum.next, %loop ] + + %a = call <8 x float> @getmat1() + %b = call <8 x float> @getmat1() + + %accum.cvt = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %accum.phi, i32 1, i32 1, i32 0, i32 1) + %muladd = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %a, <8 x float> %b, <8 x float> %accum.cvt, i1 true, i1 true, i32 1, i32 1) + %accum.next = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %muladd, i32 1, i32 1, i32 1, i32 0) + + %cc = call i1 @getcc() + br i1 %cc, label %loop, label %end + +end: + call void @lgc.cooperative.matrix.store.p0.i32.i1.i32.i32.i32.v8f32(ptr %ptr, i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> %accum.next) + ret void +} + +declare i1 @getcc() +declare <8 x float> @getmat1() + +declare <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr, i32, i1, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32, <8 x float>, i32, i32, i32, i32) +declare void @lgc.cooperative.matrix.store.p0.i32.i1.i32.i32.i32.v8f32(ptr, i32, i1, i32, i32, i32, <8 x float>) +declare <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float>, <8 x float>, <8 x float>, i1, i1, i32, i32) diff --git a/lgc/test/Transforms/CombineCooperativeMatrix/simple.lgc b/lgc/test/Transforms/CombineCooperativeMatrix/simple.lgc new file mode 100644 index 0000000000..e1570be25e --- /dev/null +++ b/lgc/test/Transforms/CombineCooperativeMatrix/simple.lgc @@ -0,0 +1,142 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-combine-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define <8 x float> @noop_transpose(<8 x float> %x) { +; CHECK-LABEL: @noop_transpose( +; CHECK-NEXT: [[T:%.*]] = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> [[X:%.*]], i32 1, i32 0) +; CHECK-NEXT: ret <8 x float> [[T]] +; + + %t = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %x, i32 1, i32 0) + ret <8 x float> %t +} + +define <8 x float> @collapse_transpose(<8 x float> %x) { +; CHECK-LABEL: @collapse_transpose( +; CHECK-NEXT: ret <8 x float> [[X:%.*]] +; + + %t1 = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %x, i32 1, i32 0) + %t2 = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %t1, i32 1, i32 0) + ret <8 x float> %t2 +} + +define <8 x float> @test_load_transpose(ptr addrspace(3) %ptr) { +; CHECK-LABEL: @test_load_transpose( +; CHECK-NEXT: [[A:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 false, i32 1, i32 0, i32 0) +; CHECK-NEXT: ret <8 x float> [[A]] +; + + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 0, i32 0) + %t = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %a, i32 1, i32 0) + ret <8 x float> %t +} + +define void @test_store_transpose(ptr addrspace(3) %ptr, <8 x float> %a) { +; CHECK-LABEL: @test_store_transpose( +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 false, i32 1, i32 0, i32 0, <8 x float> [[A:%.*]]) +; CHECK-NEXT: ret void +; + + %t = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %a, i32 1, i32 0) + call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> %t) + ret void +} + +define void @test_phi_transpose(ptr addrspace(7) %ptr, <8 x float> %init) { +; CHECK-LABEL: @test_phi_transpose( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> [[INIT:%.*]], i32 1, i32 0) +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[MATRIX:%.*]] = phi <8 x float> [ [[TMP0]], [[ENTRY:%.*]] ], [ [[TMP:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP]] = call <8 x float> @process1(<8 x float> [[MATRIX]]) +; CHECK-NEXT: [[CC:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[CC]], label [[LOOP]], label [[END:%.*]] +; CHECK: end: +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) [[PTR:%.*]], i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> [[TMP]]) +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %matrix = phi <8 x float> [ %init, %entry ], [ %matrix.new, %loop ] + %t1 = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %matrix, i32 1, i32 0) + %tmp = call <8 x float> @process1(<8 x float> %t1) + %matrix.new = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %tmp, i32 1, i32 0) + + %cc = call i1 @getcc() + br i1 %cc, label %loop, label %end + +end: + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) %ptr, i32 4, i1 false, i32 1, i32 0, i32 0, <8 x float> %matrix.new) + ret void +} + +define <8 x float> @test_relayout_simple(<8 x float> %ab) { +; CHECK-LABEL: @test_relayout_simple( +; CHECK-NEXT: ret <8 x float> [[AB:%.*]] +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %ab, i32 1, i32 1, i32 0, i32 1) + %c = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %b, i32 1, i32 1, i32 1, i32 0) + ret <8 x float> %c +} + +define <8 x float> @test_relayout_simple_reverse(<8 x float> %cd) { +; CHECK-LABEL: @test_relayout_simple_reverse( +; CHECK-NEXT: ret <8 x float> [[CD:%.*]] +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %cd, i32 1, i32 1, i32 1, i32 0) + %c = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %b, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %c +} + +define <8 x float> @test_relayout_load(ptr addrspace(3) %ptr) { +; CHECK-LABEL: @test_relayout_load( +; CHECK-NEXT: [[A:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 true, i32 1, i32 1, i32 0) +; CHECK-NEXT: ret <8 x float> [[A]] +; + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 0, i32 0) + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %a, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %b +} + +define <8 x float> @test_relayout_load2(ptr addrspace(3) %ptr) { +; CHECK-LABEL: @test_relayout_load2( +; CHECK-NEXT: [[A:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 true, i32 1, i32 0, i32 0) +; CHECK-NEXT: ret <8 x float> [[A]] +; + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 1, i32 0) + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %a, i32 1, i32 1, i32 1, i32 0) + ret <8 x float> %b +} + +define void @test_relayout_store(ptr addrspace(3) %ptr, <8 x float> %a) { +; CHECK-LABEL: @test_relayout_store( +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> [[A:%.*]]) +; CHECK-NEXT: ret void +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %a, i32 1, i32 1, i32 0, i32 1) + call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 1, i32 0, <8 x float> %b) + ret void +} + +define void @test_relayout_store2(ptr addrspace(3) %ptr, <8 x float> %a) { +; CHECK-LABEL: @test_relayout_store2( +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) [[PTR:%.*]], i32 4, i1 true, i32 1, i32 1, i32 0, <8 x float> [[A:%.*]]) +; CHECK-NEXT: ret void +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %a, i32 1, i32 1, i32 1, i32 0) + call void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3) %ptr, i32 4, i1 true, i32 1, i32 0, i32 0, <8 x float> %b) + ret void +} + +declare i1 @getcc() +declare <8 x float> @process1(<8 x float>) + +declare <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3), i32, i1, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float>, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32, <8 x float>, i32, i32, i32, i32) +declare void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3), i32, i1, i32, i32, i32, <8 x float>) +declare void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7), i32, i1, i32, i32, i32, <8 x float>) diff --git a/lgc/test/Transforms/CombineCooperativeMatrix/unhandled-inout.lgc b/lgc/test/Transforms/CombineCooperativeMatrix/unhandled-inout.lgc new file mode 100644 index 0000000000..6bc9523ab9 --- /dev/null +++ b/lgc/test/Transforms/CombineCooperativeMatrix/unhandled-inout.lgc @@ -0,0 +1,110 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-combine-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define <8 x float> @insert_transpose(<8 x float> %x) { +; CHECK-LABEL: @insert_transpose( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[GUARD:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[GUARD]], label [[LOOP:%.*]], label [[END:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[V_LOOP:%.*]] = phi <8 x float> [ [[X:%.*]], [[ENTRY:%.*]] ], [ [[MULADD:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[F:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[MULADD]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[F]], <8 x float> [[F]], <8 x float> [[V_LOOP]], i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: [[CC:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[CC]], label [[LOOP]], label [[END]] +; CHECK: end: +; CHECK-NEXT: [[R:%.*]] = phi <8 x float> [ [[MULADD]], [[LOOP]] ], [ [[X]], [[ENTRY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> [[R]], i32 1, i32 0) +; CHECK-NEXT: ret <8 x float> [[TMP0]] +; +entry: + %in.t = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %x, i32 1, i32 0) + %guard = call i1 @getcc() + br i1 %guard, label %loop, label %end + +loop: + %v.loop = phi <8 x float> [ %in.t, %entry ], [ %v.next, %loop ] + + %f = call <8 x float> @getmat1() + %pre.t = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %v.loop, i32 1, i32 0) + %muladd = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %f, <8 x float> %f, <8 x float> %pre.t, i1 true, i1 true, i32 1, i32 1) + %v.next = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %muladd, i32 1, i32 0) + + %cc = call i1 @getcc() + br i1 %cc, label %loop, label %end + +end: + %r = phi <8 x float> [ %v.next, %loop ], [ %in.t, %entry ] + ret <8 x float> %r +} + +define <8 x float> @reuse_transpose(<8 x float> %x) { +; CHECK-LABEL: @reuse_transpose( +; CHECK-NEXT: [[T1:%.*]] = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> [[X:%.*]], i32 1, i32 0) +; CHECK-NEXT: [[R:%.*]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[T1]], <8 x float> [[X]], <8 x float> zeroinitializer, i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: ret <8 x float> [[R]] +; + %t1 = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %x, i32 1, i32 0) + %t2 = call <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float> %t1, i32 1, i32 0) + %r = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %t1, <8 x float> %t2, <8 x float> zeroinitializer, i1 true, i1 true, i32 1, i32 1) + ret <8 x float> %r +} + +; NOTE: This test leaves a convert inside the loop. Rely on other passes to sink it out. +define <8 x float> @insert_convert(ptr %ptr) { +; CHECK-LABEL: @insert_convert( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[LOAD:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr [[PTR:%.*]], i32 4, i1 false, i32 1, i32 1, i32 0) +; CHECK-NEXT: [[GUARD:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[GUARD]], label [[LOOP:%.*]], label [[END:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[V_LOOP:%.*]] = phi <8 x float> [ [[LOAD]], [[ENTRY:%.*]] ], [ [[MULADD:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[F:%.*]] = call <8 x float> @getmat1() +; CHECK-NEXT: [[MULADD]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[F]], <8 x float> [[F]], <8 x float> [[V_LOOP]], i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: [[CC:%.*]] = call i1 @getcc() +; CHECK-NEXT: br i1 [[CC]], label [[LOOP]], label [[END]] +; CHECK: end: +; CHECK-NEXT: [[R:%.*]] = phi <8 x float> [ [[MULADD]], [[LOOP]] ], [ [[LOAD]], [[ENTRY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> [[R]], i32 1, i32 1, i32 1, i32 0) +; CHECK-NEXT: ret <8 x float> [[TMP0]] +; +entry: + %load = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr %ptr, i32 4, i1 false, i32 1, i32 0, i32 0) + %guard = call i1 @getcc() + br i1 %guard, label %loop, label %end + +loop: + %v.loop = phi <8 x float> [ %load, %entry ], [ %v.next, %loop ] + + %f = call <8 x float> @getmat1() + %pre = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %v.loop, i32 1, i32 1, i32 0, i32 1) + %muladd = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %f, <8 x float> %f, <8 x float> %pre, i1 true, i1 true, i32 1, i32 1) + %v.next = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %muladd, i32 1, i32 1, i32 1, i32 0) + + %cc = call i1 @getcc() + br i1 %cc, label %loop, label %end + +end: + %r = phi <8 x float> [ %v.next, %loop ], [ %load, %entry ] + ret <8 x float> %r +} + +define <8 x float> @reuse_convert(<8 x float> %x) { +; CHECK-LABEL: @reuse_convert( +; CHECK-NEXT: [[CVT1:%.*]] = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> [[X:%.*]], i32 1, i32 1, i32 0, i32 1) +; CHECK-NEXT: [[R:%.*]] = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> [[X]], <8 x float> [[X]], <8 x float> [[CVT1]], i1 true, i1 true, i32 1, i32 1) +; CHECK-NEXT: ret <8 x float> [[R]] +; + %cvt1 = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %x, i32 1, i32 1, i32 0, i32 1) + %cvt2 = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %cvt1, i32 1, i32 1, i32 1, i32 0) + %r = call <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float> %cvt2, <8 x float> %cvt2, <8 x float> %cvt1, i1 true, i1 true, i32 1, i32 1) + ret <8 x float> %r +} + +declare i1 @getcc() +declare <8 x float> @getmat1() + +declare <8 x float> @lgc.cooperative.matrix.load.v8f32.p0.i32.i1.i32.i32.i32(ptr, i32, i1, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float>, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32, <8 x float>, i32, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.muladd.v8f32.v8f32.v8f32.v8f32.i1.i1.i32.i32(<8 x float>, <8 x float>, <8 x float>, i1, i1, i32, i32) diff --git a/lgc/test/Transforms/LowerCooperativeMatrix/convert.lgc b/lgc/test/Transforms/LowerCooperativeMatrix/convert.lgc new file mode 100644 index 0000000000..dcbf7ddee2 --- /dev/null +++ b/lgc/test/Transforms/LowerCooperativeMatrix/convert.lgc @@ -0,0 +1,245 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-lower-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define <8 x float> @test_relayout_simple(<8 x float> %ab) { +; CHECK-LABEL: @test_relayout_simple( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i32 [[TMP4]], 0 +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <8 x float> [[AB:%.*]] to <8 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <8 x float> [[AB]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <8 x float> [[AB]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP3]], 2 +; CHECK-NEXT: [[TMP10:%.*]] = icmp eq i32 [[TMP9]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], float [[TMP7]], float [[TMP8]] +; CHECK-NEXT: [[B1:%.*]] = insertelement <4 x float> poison, float [[TMP11]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <8 x float> [[AB]], i64 2 +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <8 x float> [[AB]], i64 3 +; CHECK-NEXT: [[TMP14:%.*]] = and i32 [[TMP3]], 2 +; CHECK-NEXT: [[TMP15:%.*]] = icmp eq i32 [[TMP14]], 0 +; CHECK-NEXT: [[TMP16:%.*]] = select i1 [[TMP15]], float [[TMP12]], float [[TMP13]] +; CHECK-NEXT: [[B2:%.*]] = insertelement <4 x float> [[B1]], float [[TMP16]], i64 1 +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <8 x float> [[AB]], i64 4 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <8 x float> [[AB]], i64 5 +; CHECK-NEXT: [[TMP19:%.*]] = and i32 [[TMP3]], 2 +; CHECK-NEXT: [[TMP20:%.*]] = icmp eq i32 [[TMP19]], 0 +; CHECK-NEXT: [[TMP21:%.*]] = select i1 [[TMP20]], float [[TMP17]], float [[TMP18]] +; CHECK-NEXT: [[B3:%.*]] = insertelement <4 x float> [[B2]], float [[TMP21]], i64 2 +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <8 x float> [[AB]], i64 6 +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <8 x float> [[AB]], i64 7 +; CHECK-NEXT: [[TMP24:%.*]] = and i32 [[TMP3]], 2 +; CHECK-NEXT: [[TMP25:%.*]] = icmp eq i32 [[TMP24]], 0 +; CHECK-NEXT: [[TMP26:%.*]] = select i1 [[TMP25]], float [[TMP22]], float [[TMP23]] +; CHECK-NEXT: [[B4:%.*]] = insertelement <4 x float> [[B3]], float [[TMP26]], i64 3 +; CHECK-NEXT: [[TMP27:%.*]] = bitcast <4 x float> [[B4]] to <4 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = select i1 [[TMP5]], <4 x i32> zeroinitializer, <4 x i32> +; CHECK-NEXT: [[B5:%.*]] = lshr <4 x i32> [[TMP27]], [[TMP28]] +; CHECK-NEXT: [[B6:%.*]] = bitcast <4 x i32> [[B5]] to <4 x float> +; CHECK-NEXT: [[TMP29:%.*]] = shufflevector <4 x float> [[B6]], <4 x float> poison, <8 x i32> +; CHECK-NEXT: [[TMP30:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP31:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP30]]) +; CHECK-NEXT: [[TMP32:%.*]] = bitcast <8 x float> [[TMP29]] to <8 x i32> +; CHECK-NEXT: [[TMP33:%.*]] = udiv i32 [[TMP31]], 16 +; CHECK-NEXT: [[TMP34:%.*]] = and i32 [[TMP33]], 1 +; CHECK-NEXT: [[TMP35:%.*]] = icmp eq i32 [[TMP34]], 0 +; CHECK-NEXT: [[TMP36:%.*]] = shufflevector <8 x i32> [[TMP32]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <8 x i32> [[TMP36]], i64 0 +; CHECK-NEXT: [[TMP38:%.*]] = extractelement <8 x i32> [[TMP36]], i64 0 +; CHECK-NEXT: [[TMP39:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP37]], i32 [[TMP38]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP40:%.*]] = extractelement <8 x i32> [[TMP36]], i64 1 +; CHECK-NEXT: [[TMP41:%.*]] = extractelement <8 x i32> [[TMP36]], i64 1 +; CHECK-NEXT: [[TMP42:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP40]], i32 [[TMP41]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP43:%.*]] = extractelement <8 x i32> [[TMP36]], i64 2 +; CHECK-NEXT: [[TMP44:%.*]] = extractelement <8 x i32> [[TMP36]], i64 2 +; CHECK-NEXT: [[TMP45:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP43]], i32 [[TMP44]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP46:%.*]] = extractelement <8 x i32> [[TMP36]], i64 3 +; CHECK-NEXT: [[TMP47:%.*]] = extractelement <8 x i32> [[TMP36]], i64 3 +; CHECK-NEXT: [[TMP48:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP46]], i32 [[TMP47]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <8 x i32> [[TMP36]], i64 4 +; CHECK-NEXT: [[TMP50:%.*]] = extractelement <8 x i32> [[TMP36]], i64 4 +; CHECK-NEXT: [[TMP51:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP49]], i32 [[TMP50]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP52:%.*]] = extractelement <8 x i32> [[TMP36]], i64 5 +; CHECK-NEXT: [[TMP53:%.*]] = extractelement <8 x i32> [[TMP36]], i64 5 +; CHECK-NEXT: [[TMP54:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP52]], i32 [[TMP53]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP55:%.*]] = extractelement <8 x i32> [[TMP36]], i64 6 +; CHECK-NEXT: [[TMP56:%.*]] = extractelement <8 x i32> [[TMP36]], i64 6 +; CHECK-NEXT: [[TMP57:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP55]], i32 [[TMP56]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP58:%.*]] = extractelement <8 x i32> [[TMP36]], i64 7 +; CHECK-NEXT: [[TMP59:%.*]] = extractelement <8 x i32> [[TMP36]], i64 7 +; CHECK-NEXT: [[TMP60:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP58]], i32 [[TMP59]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP61:%.*]] = insertelement <8 x i32> poison, i32 [[TMP39]], i64 0 +; CHECK-NEXT: [[TMP62:%.*]] = insertelement <8 x i32> [[TMP61]], i32 [[TMP42]], i64 1 +; CHECK-NEXT: [[TMP63:%.*]] = insertelement <8 x i32> [[TMP62]], i32 [[TMP45]], i64 2 +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <8 x i32> [[TMP63]], i32 [[TMP48]], i64 3 +; CHECK-NEXT: [[TMP65:%.*]] = insertelement <8 x i32> [[TMP64]], i32 [[TMP51]], i64 4 +; CHECK-NEXT: [[TMP66:%.*]] = insertelement <8 x i32> [[TMP65]], i32 [[TMP54]], i64 5 +; CHECK-NEXT: [[TMP67:%.*]] = insertelement <8 x i32> [[TMP66]], i32 [[TMP57]], i64 6 +; CHECK-NEXT: [[TMP68:%.*]] = insertelement <8 x i32> [[TMP67]], i32 [[TMP60]], i64 7 +; CHECK-NEXT: [[TMP69:%.*]] = select i1 [[TMP35]], <8 x i32> [[TMP36]], <8 x i32> [[TMP68]] +; CHECK-NEXT: [[TMP70:%.*]] = select i1 [[TMP35]], <8 x i32> [[TMP68]], <8 x i32> [[TMP36]] +; CHECK-NEXT: [[TMP71:%.*]] = and <8 x i32> [[TMP69]], +; CHECK-NEXT: [[TMP72:%.*]] = shl <8 x i32> [[TMP70]], +; CHECK-NEXT: [[TMP73:%.*]] = or <8 x i32> [[TMP71]], [[TMP72]] +; CHECK-NEXT: [[TMP74:%.*]] = extractelement <8 x i32> [[TMP73]], i64 0 +; CHECK-NEXT: [[TMP75:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP74]]) +; CHECK-NEXT: [[TMP76:%.*]] = extractelement <8 x i32> [[TMP73]], i64 1 +; CHECK-NEXT: [[TMP77:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP76]]) +; CHECK-NEXT: [[TMP78:%.*]] = extractelement <8 x i32> [[TMP73]], i64 2 +; CHECK-NEXT: [[TMP79:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP78]]) +; CHECK-NEXT: [[TMP80:%.*]] = extractelement <8 x i32> [[TMP73]], i64 3 +; CHECK-NEXT: [[TMP81:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP80]]) +; CHECK-NEXT: [[TMP82:%.*]] = extractelement <8 x i32> [[TMP73]], i64 4 +; CHECK-NEXT: [[TMP83:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP82]]) +; CHECK-NEXT: [[TMP84:%.*]] = extractelement <8 x i32> [[TMP73]], i64 5 +; CHECK-NEXT: [[TMP85:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP84]]) +; CHECK-NEXT: [[TMP86:%.*]] = extractelement <8 x i32> [[TMP73]], i64 6 +; CHECK-NEXT: [[TMP87:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP86]]) +; CHECK-NEXT: [[TMP88:%.*]] = extractelement <8 x i32> [[TMP73]], i64 7 +; CHECK-NEXT: [[TMP89:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP88]]) +; CHECK-NEXT: [[TMP90:%.*]] = insertelement <8 x i32> poison, i32 [[TMP75]], i64 0 +; CHECK-NEXT: [[TMP91:%.*]] = insertelement <8 x i32> [[TMP90]], i32 [[TMP77]], i64 1 +; CHECK-NEXT: [[TMP92:%.*]] = insertelement <8 x i32> [[TMP91]], i32 [[TMP79]], i64 2 +; CHECK-NEXT: [[TMP93:%.*]] = insertelement <8 x i32> [[TMP92]], i32 [[TMP81]], i64 3 +; CHECK-NEXT: [[TMP94:%.*]] = insertelement <8 x i32> [[TMP93]], i32 [[TMP83]], i64 4 +; CHECK-NEXT: [[TMP95:%.*]] = insertelement <8 x i32> [[TMP94]], i32 [[TMP85]], i64 5 +; CHECK-NEXT: [[TMP96:%.*]] = insertelement <8 x i32> [[TMP95]], i32 [[TMP87]], i64 6 +; CHECK-NEXT: [[TMP97:%.*]] = insertelement <8 x i32> [[TMP96]], i32 [[TMP89]], i64 7 +; CHECK-NEXT: [[TMP98:%.*]] = icmp ult i32 [[TMP31]], 32 +; CHECK-NEXT: [[TMP99:%.*]] = select i1 [[TMP98]], <8 x i32> [[TMP73]], <8 x i32> [[TMP97]] +; CHECK-NEXT: [[TMP100:%.*]] = select i1 [[TMP98]], <8 x i32> [[TMP97]], <8 x i32> [[TMP73]] +; CHECK-NEXT: [[C7:%.*]] = shufflevector <8 x i32> [[TMP99]], <8 x i32> [[TMP100]], <8 x i32> +; CHECK-NEXT: [[TMP101:%.*]] = bitcast <8 x i32> [[C7]] to <8 x float> +; CHECK-NEXT: ret <8 x float> [[TMP101]] +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %ab, i32 1, i32 1, i32 0, i32 1) + %c = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %b, i32 1, i32 1, i32 1, i32 0) + ret <8 x float> %c +} + +define <8 x float> @test_relayout_simple_reverse(<8 x float> %cd) { +; CHECK-LABEL: @test_relayout_simple_reverse( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x float> [[CD:%.*]] to <8 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 1 +; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <8 x i32> [[TMP3]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <8 x i32> [[TMP7]], i64 0 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <8 x i32> [[TMP7]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP8]], i32 [[TMP9]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <8 x i32> [[TMP7]], i64 1 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <8 x i32> [[TMP7]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP11]], i32 [[TMP12]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <8 x i32> [[TMP7]], i64 2 +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <8 x i32> [[TMP7]], i64 2 +; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP14]], i32 [[TMP15]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <8 x i32> [[TMP7]], i64 3 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <8 x i32> [[TMP7]], i64 3 +; CHECK-NEXT: [[TMP19:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP17]], i32 [[TMP18]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <8 x i32> [[TMP7]], i64 4 +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <8 x i32> [[TMP7]], i64 4 +; CHECK-NEXT: [[TMP22:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP20]], i32 [[TMP21]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <8 x i32> [[TMP7]], i64 5 +; CHECK-NEXT: [[TMP24:%.*]] = extractelement <8 x i32> [[TMP7]], i64 5 +; CHECK-NEXT: [[TMP25:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP23]], i32 [[TMP24]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP26:%.*]] = extractelement <8 x i32> [[TMP7]], i64 6 +; CHECK-NEXT: [[TMP27:%.*]] = extractelement <8 x i32> [[TMP7]], i64 6 +; CHECK-NEXT: [[TMP28:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP26]], i32 [[TMP27]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <8 x i32> [[TMP7]], i64 7 +; CHECK-NEXT: [[TMP30:%.*]] = extractelement <8 x i32> [[TMP7]], i64 7 +; CHECK-NEXT: [[TMP31:%.*]] = call i32 @llvm.amdgcn.permlanex16(i32 [[TMP29]], i32 [[TMP30]], i32 1985229328, i32 -19088744, i1 false, i1 false) +; CHECK-NEXT: [[TMP32:%.*]] = insertelement <8 x i32> poison, i32 [[TMP10]], i64 0 +; CHECK-NEXT: [[TMP33:%.*]] = insertelement <8 x i32> [[TMP32]], i32 [[TMP13]], i64 1 +; CHECK-NEXT: [[TMP34:%.*]] = insertelement <8 x i32> [[TMP33]], i32 [[TMP16]], i64 2 +; CHECK-NEXT: [[TMP35:%.*]] = insertelement <8 x i32> [[TMP34]], i32 [[TMP19]], i64 3 +; CHECK-NEXT: [[TMP36:%.*]] = insertelement <8 x i32> [[TMP35]], i32 [[TMP22]], i64 4 +; CHECK-NEXT: [[TMP37:%.*]] = insertelement <8 x i32> [[TMP36]], i32 [[TMP25]], i64 5 +; CHECK-NEXT: [[TMP38:%.*]] = insertelement <8 x i32> [[TMP37]], i32 [[TMP28]], i64 6 +; CHECK-NEXT: [[TMP39:%.*]] = insertelement <8 x i32> [[TMP38]], i32 [[TMP31]], i64 7 +; CHECK-NEXT: [[TMP40:%.*]] = select i1 [[TMP6]], <8 x i32> [[TMP7]], <8 x i32> [[TMP39]] +; CHECK-NEXT: [[TMP41:%.*]] = select i1 [[TMP6]], <8 x i32> [[TMP39]], <8 x i32> [[TMP7]] +; CHECK-NEXT: [[TMP42:%.*]] = and <8 x i32> [[TMP40]], +; CHECK-NEXT: [[TMP43:%.*]] = shl <8 x i32> [[TMP41]], +; CHECK-NEXT: [[TMP44:%.*]] = or <8 x i32> [[TMP42]], [[TMP43]] +; CHECK-NEXT: [[TMP45:%.*]] = extractelement <8 x i32> [[TMP44]], i64 0 +; CHECK-NEXT: [[TMP46:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP45]]) +; CHECK-NEXT: [[TMP47:%.*]] = extractelement <8 x i32> [[TMP44]], i64 1 +; CHECK-NEXT: [[TMP48:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP47]]) +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <8 x i32> [[TMP44]], i64 2 +; CHECK-NEXT: [[TMP50:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP49]]) +; CHECK-NEXT: [[TMP51:%.*]] = extractelement <8 x i32> [[TMP44]], i64 3 +; CHECK-NEXT: [[TMP52:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP51]]) +; CHECK-NEXT: [[TMP53:%.*]] = extractelement <8 x i32> [[TMP44]], i64 4 +; CHECK-NEXT: [[TMP54:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP53]]) +; CHECK-NEXT: [[TMP55:%.*]] = extractelement <8 x i32> [[TMP44]], i64 5 +; CHECK-NEXT: [[TMP56:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP55]]) +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <8 x i32> [[TMP44]], i64 6 +; CHECK-NEXT: [[TMP58:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP57]]) +; CHECK-NEXT: [[TMP59:%.*]] = extractelement <8 x i32> [[TMP44]], i64 7 +; CHECK-NEXT: [[TMP60:%.*]] = call i32 @llvm.amdgcn.permlane64(i32 [[TMP59]]) +; CHECK-NEXT: [[TMP61:%.*]] = insertelement <8 x i32> poison, i32 [[TMP46]], i64 0 +; CHECK-NEXT: [[TMP62:%.*]] = insertelement <8 x i32> [[TMP61]], i32 [[TMP48]], i64 1 +; CHECK-NEXT: [[TMP63:%.*]] = insertelement <8 x i32> [[TMP62]], i32 [[TMP50]], i64 2 +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <8 x i32> [[TMP63]], i32 [[TMP52]], i64 3 +; CHECK-NEXT: [[TMP65:%.*]] = insertelement <8 x i32> [[TMP64]], i32 [[TMP54]], i64 4 +; CHECK-NEXT: [[TMP66:%.*]] = insertelement <8 x i32> [[TMP65]], i32 [[TMP56]], i64 5 +; CHECK-NEXT: [[TMP67:%.*]] = insertelement <8 x i32> [[TMP66]], i32 [[TMP58]], i64 6 +; CHECK-NEXT: [[TMP68:%.*]] = insertelement <8 x i32> [[TMP67]], i32 [[TMP60]], i64 7 +; CHECK-NEXT: [[TMP69:%.*]] = icmp ult i32 [[TMP2]], 32 +; CHECK-NEXT: [[TMP70:%.*]] = select i1 [[TMP69]], <8 x i32> [[TMP44]], <8 x i32> [[TMP68]] +; CHECK-NEXT: [[TMP71:%.*]] = select i1 [[TMP69]], <8 x i32> [[TMP68]], <8 x i32> [[TMP44]] +; CHECK-NEXT: [[B1:%.*]] = shufflevector <8 x i32> [[TMP70]], <8 x i32> [[TMP71]], <8 x i32> +; CHECK-NEXT: [[TMP72:%.*]] = bitcast <8 x i32> [[B1]] to <8 x float> +; CHECK-NEXT: [[TMP73:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP74:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP73]]) +; CHECK-NEXT: [[TMP75:%.*]] = udiv i32 [[TMP74]], 16 +; CHECK-NEXT: [[TMP76:%.*]] = and i32 [[TMP75]], 1 +; CHECK-NEXT: [[TMP77:%.*]] = icmp eq i32 [[TMP76]], 0 +; CHECK-NEXT: [[TMP78:%.*]] = bitcast <8 x float> [[TMP72]] to <8 x i32> +; CHECK-NEXT: [[TMP79:%.*]] = extractelement <8 x float> [[TMP72]], i64 0 +; CHECK-NEXT: [[TMP80:%.*]] = extractelement <8 x float> [[TMP72]], i64 1 +; CHECK-NEXT: [[TMP81:%.*]] = and i32 [[TMP75]], 2 +; CHECK-NEXT: [[TMP82:%.*]] = icmp eq i32 [[TMP81]], 0 +; CHECK-NEXT: [[TMP83:%.*]] = select i1 [[TMP82]], float [[TMP79]], float [[TMP80]] +; CHECK-NEXT: [[C2:%.*]] = insertelement <4 x float> poison, float [[TMP83]], i64 0 +; CHECK-NEXT: [[TMP84:%.*]] = extractelement <8 x float> [[TMP72]], i64 2 +; CHECK-NEXT: [[TMP85:%.*]] = extractelement <8 x float> [[TMP72]], i64 3 +; CHECK-NEXT: [[TMP86:%.*]] = and i32 [[TMP75]], 2 +; CHECK-NEXT: [[TMP87:%.*]] = icmp eq i32 [[TMP86]], 0 +; CHECK-NEXT: [[TMP88:%.*]] = select i1 [[TMP87]], float [[TMP84]], float [[TMP85]] +; CHECK-NEXT: [[C3:%.*]] = insertelement <4 x float> [[C2]], float [[TMP88]], i64 1 +; CHECK-NEXT: [[TMP89:%.*]] = extractelement <8 x float> [[TMP72]], i64 4 +; CHECK-NEXT: [[TMP90:%.*]] = extractelement <8 x float> [[TMP72]], i64 5 +; CHECK-NEXT: [[TMP91:%.*]] = and i32 [[TMP75]], 2 +; CHECK-NEXT: [[TMP92:%.*]] = icmp eq i32 [[TMP91]], 0 +; CHECK-NEXT: [[TMP93:%.*]] = select i1 [[TMP92]], float [[TMP89]], float [[TMP90]] +; CHECK-NEXT: [[C4:%.*]] = insertelement <4 x float> [[C3]], float [[TMP93]], i64 2 +; CHECK-NEXT: [[TMP94:%.*]] = extractelement <8 x float> [[TMP72]], i64 6 +; CHECK-NEXT: [[TMP95:%.*]] = extractelement <8 x float> [[TMP72]], i64 7 +; CHECK-NEXT: [[TMP96:%.*]] = and i32 [[TMP75]], 2 +; CHECK-NEXT: [[TMP97:%.*]] = icmp eq i32 [[TMP96]], 0 +; CHECK-NEXT: [[TMP98:%.*]] = select i1 [[TMP97]], float [[TMP94]], float [[TMP95]] +; CHECK-NEXT: [[C5:%.*]] = insertelement <4 x float> [[C4]], float [[TMP98]], i64 3 +; CHECK-NEXT: [[TMP99:%.*]] = bitcast <4 x float> [[C5]] to <4 x i32> +; CHECK-NEXT: [[TMP100:%.*]] = select i1 [[TMP77]], <4 x i32> zeroinitializer, <4 x i32> +; CHECK-NEXT: [[C6:%.*]] = lshr <4 x i32> [[TMP99]], [[TMP100]] +; CHECK-NEXT: [[C7:%.*]] = bitcast <4 x i32> [[C6]] to <4 x float> +; CHECK-NEXT: [[TMP101:%.*]] = shufflevector <4 x float> [[C7]], <4 x float> poison, <8 x i32> +; CHECK-NEXT: ret <8 x float> [[TMP101]] +; + %b = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %cd, i32 1, i32 1, i32 1, i32 0) + %c = call <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32 0, <8 x float> %b, i32 1, i32 1, i32 0, i32 1) + ret <8 x float> %c +} + +declare i1 @getcc() +declare <8 x float> @process1(<8 x float>) + +declare <8 x float> @lgc.cooperative.matrix.load.v8f32.p3.i32.i1.i32.i32.i32(ptr addrspace(3), i32, i1, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.transpose.v8f32.v8f32.i32.i32(<8 x float>, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.convert.v8f32.i32.v8f32.i32.i32.i32.i32(i32, <8 x float>, i32, i32, i32, i32) +declare void @lgc.cooperative.matrix.store.p3.i32.i1.i32.i32.i32.v8f32(ptr addrspace(3), i32, i1, i32, i32, i32, <8 x float>) +declare void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7), i32, i1, i32, i32, i32, <8 x float>) diff --git a/lgc/test/Transforms/LowerCooperativeMatrix/extract-insert.lgc b/lgc/test/Transforms/LowerCooperativeMatrix/extract-insert.lgc new file mode 100644 index 0000000000..6ddd7c21dd --- /dev/null +++ b/lgc/test/Transforms/LowerCooperativeMatrix/extract-insert.lgc @@ -0,0 +1,47 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-lower-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define i32 @test_length_f16() !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_length_f16( +; CHECK-NEXT: ret i32 16 +; + %a = call i32 @lgc.cooperative.matrix.length.i32.i32.i32(i32 1, i32 0) + ret i32 %a +} + +define half @test_extract_f16(<8 x float> %matrix) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_extract_f16( +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x float> [[MATRIX:%.*]] to <16 x half> +; CHECK-NEXT: [[R:%.*]] = extractelement <16 x half> [[TMP1]], i32 5 +; CHECK-NEXT: ret half [[R]] +; + %r = call half @lgc.cooperative.matrix.extract.f16.v8f32.i32.i32.i32(<8 x float> %matrix, i32 5, i32 1, i32 0) + ret half %r +} + +define <8 x float> @test_insert_f16(<8 x float> %matrix, half %x) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_insert_f16( +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x float> [[MATRIX:%.*]] to <16 x half> +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <16 x half> [[TMP1]], half [[X:%.*]], i32 5 +; CHECK-NEXT: [[R:%.*]] = bitcast <16 x half> [[TMP2]] to <8 x float> +; CHECK-NEXT: ret <8 x float> [[R]] +; + %r = call <8 x float> @lgc.cooperative.matrix.insert.v8f32.v8f32.f16.i32.i32.i32(<8 x float> %matrix, half %x, i32 5, i32 1, i32 0) + ret <8 x float> %r +} + +declare i32 @lgc.cooperative.matrix.length.i32.i32.i32(i32, i32) +declare half @lgc.cooperative.matrix.extract.f16.v8f32.i32.i32.i32(<8 x float>, i32, i32, i32) +declare <8 x float> @lgc.cooperative.matrix.insert.v8f32.v8f32.f16.i32.i32.i32(<8 x float>, half, i32, i32, i32) + +!llpc.compute.mode = !{!0} +!lgc.client = !{!1} +!lgc.options = !{!2} +!lgc.options.CS = !{!3} + +!0 = !{i32 128, i32 2, i32 1} +!1 = !{!"Vulkan"} +!2 = !{i32 -2108299168, i32 -1199997545, i32 1667044824, i32 -422575072, i32 1, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 1, i32 0, i32 0, i32 -1} +!3 = !{i32 219437737, i32 -1317595285, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 64, i32 64, i32 0, i32 0, i32 3, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 20, i32 1800} +!8 = !{i32 5} +!9 = !{i32 7} diff --git a/lgc/test/Transforms/LowerCooperativeMatrix/lit.local.cfg b/lgc/test/Transforms/LowerCooperativeMatrix/lit.local.cfg new file mode 100644 index 0000000000..a4266bc874 --- /dev/null +++ b/lgc/test/Transforms/LowerCooperativeMatrix/lit.local.cfg @@ -0,0 +1,2 @@ +if "vki_cooperative_matrix" not in config.available_features: + config.unsupported = True diff --git a/lgc/test/Transforms/LowerCooperativeMatrix/load-wave64.lgc b/lgc/test/Transforms/LowerCooperativeMatrix/load-wave64.lgc new file mode 100644 index 0000000000..44cde67243 --- /dev/null +++ b/lgc/test/Transforms/LowerCooperativeMatrix/load-wave64.lgc @@ -0,0 +1,345 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-lower-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define <8 x float> @test_f16_ab_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f16_ab_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[TMP4]], 0 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr half, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP6]] +; CHECK-NEXT: [[A1:%.*]] = load half, ptr addrspace(7) [[TMP7]], align 2 +; CHECK-NEXT: [[TMP8:%.*]] = insertelement <16 x half> poison, half [[A1]], i64 0 +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP10]] +; CHECK-NEXT: [[A2:%.*]] = load half, ptr addrspace(7) [[TMP11]], align 2 +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <16 x half> [[TMP8]], half [[A2]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP4]], 320 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP13]], 0 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP14]] +; CHECK-NEXT: [[A3:%.*]] = load half, ptr addrspace(7) [[TMP15]], align 2 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <16 x half> [[TMP12]], half [[A3]], i64 2 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP4]], 480 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP17]], 0 +; CHECK-NEXT: [[TMP19:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP18]] +; CHECK-NEXT: [[A4:%.*]] = load half, ptr addrspace(7) [[TMP19]], align 2 +; CHECK-NEXT: [[TMP20:%.*]] = insertelement <16 x half> [[TMP16]], half [[A4]], i64 3 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP4]], 640 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP21]], 0 +; CHECK-NEXT: [[TMP23:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP22]] +; CHECK-NEXT: [[A5:%.*]] = load half, ptr addrspace(7) [[TMP23]], align 2 +; CHECK-NEXT: [[TMP24:%.*]] = insertelement <16 x half> [[TMP20]], half [[A5]], i64 4 +; CHECK-NEXT: [[TMP25:%.*]] = add i32 [[TMP4]], 800 +; CHECK-NEXT: [[TMP26:%.*]] = add i32 [[TMP25]], 0 +; CHECK-NEXT: [[TMP27:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP26]] +; CHECK-NEXT: [[A6:%.*]] = load half, ptr addrspace(7) [[TMP27]], align 2 +; CHECK-NEXT: [[TMP28:%.*]] = insertelement <16 x half> [[TMP24]], half [[A6]], i64 5 +; CHECK-NEXT: [[TMP29:%.*]] = add i32 [[TMP4]], 960 +; CHECK-NEXT: [[TMP30:%.*]] = add i32 [[TMP29]], 0 +; CHECK-NEXT: [[TMP31:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP30]] +; CHECK-NEXT: [[A7:%.*]] = load half, ptr addrspace(7) [[TMP31]], align 2 +; CHECK-NEXT: [[TMP32:%.*]] = insertelement <16 x half> [[TMP28]], half [[A7]], i64 6 +; CHECK-NEXT: [[TMP33:%.*]] = add i32 [[TMP4]], 1120 +; CHECK-NEXT: [[TMP34:%.*]] = add i32 [[TMP33]], 0 +; CHECK-NEXT: [[TMP35:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP34]] +; CHECK-NEXT: [[A8:%.*]] = load half, ptr addrspace(7) [[TMP35]], align 2 +; CHECK-NEXT: [[TMP36:%.*]] = insertelement <16 x half> [[TMP32]], half [[A8]], i64 7 +; CHECK-NEXT: [[TMP37:%.*]] = add i32 [[TMP4]], 1280 +; CHECK-NEXT: [[TMP38:%.*]] = add i32 [[TMP37]], 0 +; CHECK-NEXT: [[TMP39:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP38]] +; CHECK-NEXT: [[A9:%.*]] = load half, ptr addrspace(7) [[TMP39]], align 2 +; CHECK-NEXT: [[TMP40:%.*]] = insertelement <16 x half> [[TMP36]], half [[A9]], i64 8 +; CHECK-NEXT: [[TMP41:%.*]] = add i32 [[TMP4]], 1440 +; CHECK-NEXT: [[TMP42:%.*]] = add i32 [[TMP41]], 0 +; CHECK-NEXT: [[TMP43:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP42]] +; CHECK-NEXT: [[A10:%.*]] = load half, ptr addrspace(7) [[TMP43]], align 2 +; CHECK-NEXT: [[TMP44:%.*]] = insertelement <16 x half> [[TMP40]], half [[A10]], i64 9 +; CHECK-NEXT: [[TMP45:%.*]] = add i32 [[TMP4]], 1600 +; CHECK-NEXT: [[TMP46:%.*]] = add i32 [[TMP45]], 0 +; CHECK-NEXT: [[TMP47:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP46]] +; CHECK-NEXT: [[A11:%.*]] = load half, ptr addrspace(7) [[TMP47]], align 2 +; CHECK-NEXT: [[TMP48:%.*]] = insertelement <16 x half> [[TMP44]], half [[A11]], i64 10 +; CHECK-NEXT: [[TMP49:%.*]] = add i32 [[TMP4]], 1760 +; CHECK-NEXT: [[TMP50:%.*]] = add i32 [[TMP49]], 0 +; CHECK-NEXT: [[TMP51:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP50]] +; CHECK-NEXT: [[A12:%.*]] = load half, ptr addrspace(7) [[TMP51]], align 2 +; CHECK-NEXT: [[TMP52:%.*]] = insertelement <16 x half> [[TMP48]], half [[A12]], i64 11 +; CHECK-NEXT: [[TMP53:%.*]] = add i32 [[TMP4]], 1920 +; CHECK-NEXT: [[TMP54:%.*]] = add i32 [[TMP53]], 0 +; CHECK-NEXT: [[TMP55:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP54]] +; CHECK-NEXT: [[A13:%.*]] = load half, ptr addrspace(7) [[TMP55]], align 2 +; CHECK-NEXT: [[TMP56:%.*]] = insertelement <16 x half> [[TMP52]], half [[A13]], i64 12 +; CHECK-NEXT: [[TMP57:%.*]] = add i32 [[TMP4]], 2080 +; CHECK-NEXT: [[TMP58:%.*]] = add i32 [[TMP57]], 0 +; CHECK-NEXT: [[TMP59:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP58]] +; CHECK-NEXT: [[A14:%.*]] = load half, ptr addrspace(7) [[TMP59]], align 2 +; CHECK-NEXT: [[TMP60:%.*]] = insertelement <16 x half> [[TMP56]], half [[A14]], i64 13 +; CHECK-NEXT: [[TMP61:%.*]] = add i32 [[TMP4]], 2240 +; CHECK-NEXT: [[TMP62:%.*]] = add i32 [[TMP61]], 0 +; CHECK-NEXT: [[TMP63:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP62]] +; CHECK-NEXT: [[A15:%.*]] = load half, ptr addrspace(7) [[TMP63]], align 2 +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <16 x half> [[TMP60]], half [[A15]], i64 14 +; CHECK-NEXT: [[TMP65:%.*]] = add i32 [[TMP4]], 2400 +; CHECK-NEXT: [[TMP66:%.*]] = add i32 [[TMP65]], 0 +; CHECK-NEXT: [[TMP67:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP66]] +; CHECK-NEXT: [[A16:%.*]] = load half, ptr addrspace(7) [[TMP67]], align 2 +; CHECK-NEXT: [[TMP68:%.*]] = insertelement <16 x half> [[TMP64]], half [[A16]], i64 15 +; CHECK-NEXT: [[TMP69:%.*]] = bitcast <16 x half> [[TMP68]] to <8 x float> +; CHECK-NEXT: ret <8 x float> [[TMP69]] +; + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 1, i32 0, i32 0) + ret <8 x float> %a +} + +define <8 x float> @test_f16_cd_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f16_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr half, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP8]] +; CHECK-NEXT: [[A1:%.*]] = load half, ptr addrspace(7) [[TMP9]], align 2 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x half> poison, half [[A1]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP11]], 0 +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP12]] +; CHECK-NEXT: [[A2:%.*]] = load half, ptr addrspace(7) [[TMP13]], align 2 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x half> [[TMP10]], half [[A2]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP15]], 0 +; CHECK-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP16]] +; CHECK-NEXT: [[A3:%.*]] = load half, ptr addrspace(7) [[TMP17]], align 2 +; CHECK-NEXT: [[TMP18:%.*]] = insertelement <4 x half> [[TMP14]], half [[A3]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP19]], 0 +; CHECK-NEXT: [[TMP21:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP20]] +; CHECK-NEXT: [[A4:%.*]] = load half, ptr addrspace(7) [[TMP21]], align 2 +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x half> [[TMP18]], half [[A4]], i64 3 +; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x half> [[TMP22]], <4 x half> poison, <16 x i32> +; CHECK-NEXT: [[TMP24:%.*]] = bitcast <16 x half> [[TMP23]] to <8 x float> +; CHECK-NEXT: ret <8 x float> [[TMP24]] +; + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 1, i32 1, i32 0) + ret <8 x float> %a +} + +define <8 x i32> @test_i16_ab_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i16_ab_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[TMP4]], 0 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP6]] +; CHECK-NEXT: [[A1:%.*]] = load i16, ptr addrspace(7) [[TMP7]], align 2 +; CHECK-NEXT: [[TMP8:%.*]] = insertelement <16 x i16> poison, i16 [[A1]], i64 0 +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP10]] +; CHECK-NEXT: [[A2:%.*]] = load i16, ptr addrspace(7) [[TMP11]], align 2 +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <16 x i16> [[TMP8]], i16 [[A2]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP4]], 320 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP13]], 0 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP14]] +; CHECK-NEXT: [[A3:%.*]] = load i16, ptr addrspace(7) [[TMP15]], align 2 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <16 x i16> [[TMP12]], i16 [[A3]], i64 2 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP4]], 480 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP17]], 0 +; CHECK-NEXT: [[TMP19:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP18]] +; CHECK-NEXT: [[A4:%.*]] = load i16, ptr addrspace(7) [[TMP19]], align 2 +; CHECK-NEXT: [[TMP20:%.*]] = insertelement <16 x i16> [[TMP16]], i16 [[A4]], i64 3 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP4]], 640 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP21]], 0 +; CHECK-NEXT: [[TMP23:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP22]] +; CHECK-NEXT: [[A5:%.*]] = load i16, ptr addrspace(7) [[TMP23]], align 2 +; CHECK-NEXT: [[TMP24:%.*]] = insertelement <16 x i16> [[TMP20]], i16 [[A5]], i64 4 +; CHECK-NEXT: [[TMP25:%.*]] = add i32 [[TMP4]], 800 +; CHECK-NEXT: [[TMP26:%.*]] = add i32 [[TMP25]], 0 +; CHECK-NEXT: [[TMP27:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP26]] +; CHECK-NEXT: [[A6:%.*]] = load i16, ptr addrspace(7) [[TMP27]], align 2 +; CHECK-NEXT: [[TMP28:%.*]] = insertelement <16 x i16> [[TMP24]], i16 [[A6]], i64 5 +; CHECK-NEXT: [[TMP29:%.*]] = add i32 [[TMP4]], 960 +; CHECK-NEXT: [[TMP30:%.*]] = add i32 [[TMP29]], 0 +; CHECK-NEXT: [[TMP31:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP30]] +; CHECK-NEXT: [[A7:%.*]] = load i16, ptr addrspace(7) [[TMP31]], align 2 +; CHECK-NEXT: [[TMP32:%.*]] = insertelement <16 x i16> [[TMP28]], i16 [[A7]], i64 6 +; CHECK-NEXT: [[TMP33:%.*]] = add i32 [[TMP4]], 1120 +; CHECK-NEXT: [[TMP34:%.*]] = add i32 [[TMP33]], 0 +; CHECK-NEXT: [[TMP35:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP34]] +; CHECK-NEXT: [[A8:%.*]] = load i16, ptr addrspace(7) [[TMP35]], align 2 +; CHECK-NEXT: [[TMP36:%.*]] = insertelement <16 x i16> [[TMP32]], i16 [[A8]], i64 7 +; CHECK-NEXT: [[TMP37:%.*]] = add i32 [[TMP4]], 1280 +; CHECK-NEXT: [[TMP38:%.*]] = add i32 [[TMP37]], 0 +; CHECK-NEXT: [[TMP39:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP38]] +; CHECK-NEXT: [[A9:%.*]] = load i16, ptr addrspace(7) [[TMP39]], align 2 +; CHECK-NEXT: [[TMP40:%.*]] = insertelement <16 x i16> [[TMP36]], i16 [[A9]], i64 8 +; CHECK-NEXT: [[TMP41:%.*]] = add i32 [[TMP4]], 1440 +; CHECK-NEXT: [[TMP42:%.*]] = add i32 [[TMP41]], 0 +; CHECK-NEXT: [[TMP43:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP42]] +; CHECK-NEXT: [[A10:%.*]] = load i16, ptr addrspace(7) [[TMP43]], align 2 +; CHECK-NEXT: [[TMP44:%.*]] = insertelement <16 x i16> [[TMP40]], i16 [[A10]], i64 9 +; CHECK-NEXT: [[TMP45:%.*]] = add i32 [[TMP4]], 1600 +; CHECK-NEXT: [[TMP46:%.*]] = add i32 [[TMP45]], 0 +; CHECK-NEXT: [[TMP47:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP46]] +; CHECK-NEXT: [[A11:%.*]] = load i16, ptr addrspace(7) [[TMP47]], align 2 +; CHECK-NEXT: [[TMP48:%.*]] = insertelement <16 x i16> [[TMP44]], i16 [[A11]], i64 10 +; CHECK-NEXT: [[TMP49:%.*]] = add i32 [[TMP4]], 1760 +; CHECK-NEXT: [[TMP50:%.*]] = add i32 [[TMP49]], 0 +; CHECK-NEXT: [[TMP51:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP50]] +; CHECK-NEXT: [[A12:%.*]] = load i16, ptr addrspace(7) [[TMP51]], align 2 +; CHECK-NEXT: [[TMP52:%.*]] = insertelement <16 x i16> [[TMP48]], i16 [[A12]], i64 11 +; CHECK-NEXT: [[TMP53:%.*]] = add i32 [[TMP4]], 1920 +; CHECK-NEXT: [[TMP54:%.*]] = add i32 [[TMP53]], 0 +; CHECK-NEXT: [[TMP55:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP54]] +; CHECK-NEXT: [[A13:%.*]] = load i16, ptr addrspace(7) [[TMP55]], align 2 +; CHECK-NEXT: [[TMP56:%.*]] = insertelement <16 x i16> [[TMP52]], i16 [[A13]], i64 12 +; CHECK-NEXT: [[TMP57:%.*]] = add i32 [[TMP4]], 2080 +; CHECK-NEXT: [[TMP58:%.*]] = add i32 [[TMP57]], 0 +; CHECK-NEXT: [[TMP59:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP58]] +; CHECK-NEXT: [[A14:%.*]] = load i16, ptr addrspace(7) [[TMP59]], align 2 +; CHECK-NEXT: [[TMP60:%.*]] = insertelement <16 x i16> [[TMP56]], i16 [[A14]], i64 13 +; CHECK-NEXT: [[TMP61:%.*]] = add i32 [[TMP4]], 2240 +; CHECK-NEXT: [[TMP62:%.*]] = add i32 [[TMP61]], 0 +; CHECK-NEXT: [[TMP63:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP62]] +; CHECK-NEXT: [[A15:%.*]] = load i16, ptr addrspace(7) [[TMP63]], align 2 +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <16 x i16> [[TMP60]], i16 [[A15]], i64 14 +; CHECK-NEXT: [[TMP65:%.*]] = add i32 [[TMP4]], 2400 +; CHECK-NEXT: [[TMP66:%.*]] = add i32 [[TMP65]], 0 +; CHECK-NEXT: [[TMP67:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP66]] +; CHECK-NEXT: [[A16:%.*]] = load i16, ptr addrspace(7) [[TMP67]], align 2 +; CHECK-NEXT: [[TMP68:%.*]] = insertelement <16 x i16> [[TMP64]], i16 [[A16]], i64 15 +; CHECK-NEXT: [[TMP69:%.*]] = bitcast <16 x i16> [[TMP68]] to <8 x i32> +; CHECK-NEXT: ret <8 x i32> [[TMP69]] +; + %a = call <8 x i32> @lgc.cooperative.matrix.load.v8i32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 4, i32 0, i32 0) + ret <8 x i32> %a +} + +define <8 x i32> @test_i16_cd_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i16_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP8]] +; CHECK-NEXT: [[A1:%.*]] = load i16, ptr addrspace(7) [[TMP9]], align 2 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x i16> poison, i16 [[A1]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP11]], 0 +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP12]] +; CHECK-NEXT: [[A2:%.*]] = load i16, ptr addrspace(7) [[TMP13]], align 2 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x i16> [[TMP10]], i16 [[A2]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP15]], 0 +; CHECK-NEXT: [[TMP17:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP16]] +; CHECK-NEXT: [[A3:%.*]] = load i16, ptr addrspace(7) [[TMP17]], align 2 +; CHECK-NEXT: [[TMP18:%.*]] = insertelement <4 x i16> [[TMP14]], i16 [[A3]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP19]], 0 +; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP20]] +; CHECK-NEXT: [[A4:%.*]] = load i16, ptr addrspace(7) [[TMP21]], align 2 +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i16> [[TMP18]], i16 [[A4]], i64 3 +; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x i16> [[TMP22]], <4 x i16> poison, <16 x i32> +; CHECK-NEXT: [[TMP24:%.*]] = bitcast <16 x i16> [[TMP23]] to <8 x i32> +; CHECK-NEXT: ret <8 x i32> [[TMP24]] +; + %a = call <8 x i32> @lgc.cooperative.matrix.load.v8i32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 4, i32 1, i32 0) + ret <8 x i32> %a +} + +define <8 x float> @test_f32_cd_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f32_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr float, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP8]] +; CHECK-NEXT: [[A1:%.*]] = load float, ptr addrspace(7) [[TMP9]], align 4 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[A1]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP11]], 0 +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP12]] +; CHECK-NEXT: [[A2:%.*]] = load float, ptr addrspace(7) [[TMP13]], align 4 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x float> [[TMP10]], float [[A2]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP15]], 0 +; CHECK-NEXT: [[TMP17:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP16]] +; CHECK-NEXT: [[A3:%.*]] = load float, ptr addrspace(7) [[TMP17]], align 4 +; CHECK-NEXT: [[TMP18:%.*]] = insertelement <4 x float> [[TMP14]], float [[A3]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP19]], 0 +; CHECK-NEXT: [[TMP21:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP20]] +; CHECK-NEXT: [[A4:%.*]] = load float, ptr addrspace(7) [[TMP21]], align 4 +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x float> [[TMP18]], float [[A4]], i64 3 +; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x float> [[TMP22]], <4 x float> poison, <8 x i32> +; CHECK-NEXT: ret <8 x float> [[TMP23]] +; + %a = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 640, i1 false, i32 2, i32 1, i32 0) + ret <8 x float> %a +} + +define <8 x i32> @test_i32_cd_layout(ptr addrspace(7) %ptr) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i32_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP8]] +; CHECK-NEXT: [[A1:%.*]] = load i32, ptr addrspace(7) [[TMP9]], align 4 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x i32> poison, i32 [[A1]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP11]], 0 +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP12]] +; CHECK-NEXT: [[A2:%.*]] = load i32, ptr addrspace(7) [[TMP13]], align 4 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x i32> [[TMP10]], i32 [[A2]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP15]], 0 +; CHECK-NEXT: [[TMP17:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP16]] +; CHECK-NEXT: [[A3:%.*]] = load i32, ptr addrspace(7) [[TMP17]], align 4 +; CHECK-NEXT: [[TMP18:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[A3]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP19]], 0 +; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP20]] +; CHECK-NEXT: [[A4:%.*]] = load i32, ptr addrspace(7) [[TMP21]], align 4 +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[A4]], i64 3 +; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x i32> [[TMP22]], <4 x i32> poison, <8 x i32> +; CHECK-NEXT: ret <8 x i32> [[TMP23]] +; + %a = call <8 x i32> @lgc.cooperative.matrix.load.v8i32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) %ptr, i32 640, i1 false, i32 5, i32 1, i32 0) + ret <8 x i32> %a +} + +declare <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7), i32, i1, i32, i32, i32) +declare <8 x i32> @lgc.cooperative.matrix.load.v8i32.p7.i32.i1.i32.i32.i32(ptr addrspace(7), i32, i1, i32, i32, i32) + +!llpc.compute.mode = !{!0} +!lgc.client = !{!1} +!lgc.options = !{!2} +!lgc.options.CS = !{!3} + +!0 = !{i32 128, i32 2, i32 1} +!1 = !{!"Vulkan"} +!2 = !{i32 -2108299168, i32 -1199997545, i32 1667044824, i32 -422575072, i32 1, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 1, i32 0, i32 0, i32 -1} +!3 = !{i32 219437737, i32 -1317595285, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 64, i32 64, i32 0, i32 0, i32 3, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 20, i32 1800} +!8 = !{i32 5} +!9 = !{i32 7} diff --git a/lgc/test/Transforms/LowerCooperativeMatrix/store-wave64.lgc b/lgc/test/Transforms/LowerCooperativeMatrix/store-wave64.lgc new file mode 100644 index 0000000000..798c807644 --- /dev/null +++ b/lgc/test/Transforms/LowerCooperativeMatrix/store-wave64.lgc @@ -0,0 +1,345 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool lgc +; RUN: lgc -o - -passes=lgc-lower-cooperative-matrix %s | FileCheck --check-prefixes=CHECK %s + +define void @test_f16_ab_layout(ptr addrspace(7) %ptr, <8 x float> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f16_ab_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x float> [[A:%.*]] to <16 x half> +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP4]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr half, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP7]] +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <16 x half> [[TMP5]], i64 0 +; CHECK-NEXT: store half [[TMP9]], ptr addrspace(7) [[TMP8]], align 2 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP10]], 0 +; CHECK-NEXT: [[TMP12:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP11]] +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <16 x half> [[TMP5]], i64 1 +; CHECK-NEXT: store half [[TMP13]], ptr addrspace(7) [[TMP12]], align 2 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP4]], 320 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP14]], 0 +; CHECK-NEXT: [[TMP16:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP15]] +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <16 x half> [[TMP5]], i64 2 +; CHECK-NEXT: store half [[TMP17]], ptr addrspace(7) [[TMP16]], align 2 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP4]], 480 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP18]], 0 +; CHECK-NEXT: [[TMP20:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP19]] +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <16 x half> [[TMP5]], i64 3 +; CHECK-NEXT: store half [[TMP21]], ptr addrspace(7) [[TMP20]], align 2 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP4]], 640 +; CHECK-NEXT: [[TMP23:%.*]] = add i32 [[TMP22]], 0 +; CHECK-NEXT: [[TMP24:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP23]] +; CHECK-NEXT: [[TMP25:%.*]] = extractelement <16 x half> [[TMP5]], i64 4 +; CHECK-NEXT: store half [[TMP25]], ptr addrspace(7) [[TMP24]], align 2 +; CHECK-NEXT: [[TMP26:%.*]] = add i32 [[TMP4]], 800 +; CHECK-NEXT: [[TMP27:%.*]] = add i32 [[TMP26]], 0 +; CHECK-NEXT: [[TMP28:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP27]] +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <16 x half> [[TMP5]], i64 5 +; CHECK-NEXT: store half [[TMP29]], ptr addrspace(7) [[TMP28]], align 2 +; CHECK-NEXT: [[TMP30:%.*]] = add i32 [[TMP4]], 960 +; CHECK-NEXT: [[TMP31:%.*]] = add i32 [[TMP30]], 0 +; CHECK-NEXT: [[TMP32:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP31]] +; CHECK-NEXT: [[TMP33:%.*]] = extractelement <16 x half> [[TMP5]], i64 6 +; CHECK-NEXT: store half [[TMP33]], ptr addrspace(7) [[TMP32]], align 2 +; CHECK-NEXT: [[TMP34:%.*]] = add i32 [[TMP4]], 1120 +; CHECK-NEXT: [[TMP35:%.*]] = add i32 [[TMP34]], 0 +; CHECK-NEXT: [[TMP36:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP35]] +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <16 x half> [[TMP5]], i64 7 +; CHECK-NEXT: store half [[TMP37]], ptr addrspace(7) [[TMP36]], align 2 +; CHECK-NEXT: [[TMP38:%.*]] = add i32 [[TMP4]], 1280 +; CHECK-NEXT: [[TMP39:%.*]] = add i32 [[TMP38]], 0 +; CHECK-NEXT: [[TMP40:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP39]] +; CHECK-NEXT: [[TMP41:%.*]] = extractelement <16 x half> [[TMP5]], i64 8 +; CHECK-NEXT: store half [[TMP41]], ptr addrspace(7) [[TMP40]], align 2 +; CHECK-NEXT: [[TMP42:%.*]] = add i32 [[TMP4]], 1440 +; CHECK-NEXT: [[TMP43:%.*]] = add i32 [[TMP42]], 0 +; CHECK-NEXT: [[TMP44:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP43]] +; CHECK-NEXT: [[TMP45:%.*]] = extractelement <16 x half> [[TMP5]], i64 9 +; CHECK-NEXT: store half [[TMP45]], ptr addrspace(7) [[TMP44]], align 2 +; CHECK-NEXT: [[TMP46:%.*]] = add i32 [[TMP4]], 1600 +; CHECK-NEXT: [[TMP47:%.*]] = add i32 [[TMP46]], 0 +; CHECK-NEXT: [[TMP48:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP47]] +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <16 x half> [[TMP5]], i64 10 +; CHECK-NEXT: store half [[TMP49]], ptr addrspace(7) [[TMP48]], align 2 +; CHECK-NEXT: [[TMP50:%.*]] = add i32 [[TMP4]], 1760 +; CHECK-NEXT: [[TMP51:%.*]] = add i32 [[TMP50]], 0 +; CHECK-NEXT: [[TMP52:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP51]] +; CHECK-NEXT: [[TMP53:%.*]] = extractelement <16 x half> [[TMP5]], i64 11 +; CHECK-NEXT: store half [[TMP53]], ptr addrspace(7) [[TMP52]], align 2 +; CHECK-NEXT: [[TMP54:%.*]] = add i32 [[TMP4]], 1920 +; CHECK-NEXT: [[TMP55:%.*]] = add i32 [[TMP54]], 0 +; CHECK-NEXT: [[TMP56:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP55]] +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <16 x half> [[TMP5]], i64 12 +; CHECK-NEXT: store half [[TMP57]], ptr addrspace(7) [[TMP56]], align 2 +; CHECK-NEXT: [[TMP58:%.*]] = add i32 [[TMP4]], 2080 +; CHECK-NEXT: [[TMP59:%.*]] = add i32 [[TMP58]], 0 +; CHECK-NEXT: [[TMP60:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP59]] +; CHECK-NEXT: [[TMP61:%.*]] = extractelement <16 x half> [[TMP5]], i64 13 +; CHECK-NEXT: store half [[TMP61]], ptr addrspace(7) [[TMP60]], align 2 +; CHECK-NEXT: [[TMP62:%.*]] = add i32 [[TMP4]], 2240 +; CHECK-NEXT: [[TMP63:%.*]] = add i32 [[TMP62]], 0 +; CHECK-NEXT: [[TMP64:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP63]] +; CHECK-NEXT: [[TMP65:%.*]] = extractelement <16 x half> [[TMP5]], i64 14 +; CHECK-NEXT: store half [[TMP65]], ptr addrspace(7) [[TMP64]], align 2 +; CHECK-NEXT: [[TMP66:%.*]] = add i32 [[TMP4]], 2400 +; CHECK-NEXT: [[TMP67:%.*]] = add i32 [[TMP66]], 0 +; CHECK-NEXT: [[TMP68:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP67]] +; CHECK-NEXT: [[TMP69:%.*]] = extractelement <16 x half> [[TMP5]], i64 15 +; CHECK-NEXT: store half [[TMP69]], ptr addrspace(7) [[TMP68]], align 2 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 1, i32 0, i32 0, <8 x float> %a) + ret void +} + +define void @test_f16_cd_layout(ptr addrspace(7) %ptr, <8 x float> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f16_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <8 x float> [[A:%.*]] to <16 x half> +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <16 x half> [[TMP7]], <16 x half> poison, <4 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr half, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <4 x half> [[TMP8]], i64 0 +; CHECK-NEXT: store half [[TMP12]], ptr addrspace(7) [[TMP11]], align 2 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP13]], 0 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP14]] +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x half> [[TMP8]], i64 1 +; CHECK-NEXT: store half [[TMP16]], ptr addrspace(7) [[TMP15]], align 2 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP17]], 0 +; CHECK-NEXT: [[TMP19:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP18]] +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x half> [[TMP8]], i64 2 +; CHECK-NEXT: store half [[TMP20]], ptr addrspace(7) [[TMP19]], align 2 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP21]], 0 +; CHECK-NEXT: [[TMP23:%.*]] = getelementptr half, ptr addrspace(7) [[PTR]], i32 [[TMP22]] +; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x half> [[TMP8]], i64 3 +; CHECK-NEXT: store half [[TMP24]], ptr addrspace(7) [[TMP23]], align 2 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 1, i32 1, i32 0, <8 x float> %a) + ret void +} + +define void @test_i16_ab_layout(ptr addrspace(7) %ptr, <8 x i32> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i16_ab_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x i32> [[A:%.*]] to <16 x i16> +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP4]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP7]] +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <16 x i16> [[TMP5]], i64 0 +; CHECK-NEXT: store i16 [[TMP9]], ptr addrspace(7) [[TMP8]], align 2 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP10]], 0 +; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP11]] +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <16 x i16> [[TMP5]], i64 1 +; CHECK-NEXT: store i16 [[TMP13]], ptr addrspace(7) [[TMP12]], align 2 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP4]], 320 +; CHECK-NEXT: [[TMP15:%.*]] = add i32 [[TMP14]], 0 +; CHECK-NEXT: [[TMP16:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP15]] +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <16 x i16> [[TMP5]], i64 2 +; CHECK-NEXT: store i16 [[TMP17]], ptr addrspace(7) [[TMP16]], align 2 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP4]], 480 +; CHECK-NEXT: [[TMP19:%.*]] = add i32 [[TMP18]], 0 +; CHECK-NEXT: [[TMP20:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP19]] +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <16 x i16> [[TMP5]], i64 3 +; CHECK-NEXT: store i16 [[TMP21]], ptr addrspace(7) [[TMP20]], align 2 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP4]], 640 +; CHECK-NEXT: [[TMP23:%.*]] = add i32 [[TMP22]], 0 +; CHECK-NEXT: [[TMP24:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP23]] +; CHECK-NEXT: [[TMP25:%.*]] = extractelement <16 x i16> [[TMP5]], i64 4 +; CHECK-NEXT: store i16 [[TMP25]], ptr addrspace(7) [[TMP24]], align 2 +; CHECK-NEXT: [[TMP26:%.*]] = add i32 [[TMP4]], 800 +; CHECK-NEXT: [[TMP27:%.*]] = add i32 [[TMP26]], 0 +; CHECK-NEXT: [[TMP28:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP27]] +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <16 x i16> [[TMP5]], i64 5 +; CHECK-NEXT: store i16 [[TMP29]], ptr addrspace(7) [[TMP28]], align 2 +; CHECK-NEXT: [[TMP30:%.*]] = add i32 [[TMP4]], 960 +; CHECK-NEXT: [[TMP31:%.*]] = add i32 [[TMP30]], 0 +; CHECK-NEXT: [[TMP32:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP31]] +; CHECK-NEXT: [[TMP33:%.*]] = extractelement <16 x i16> [[TMP5]], i64 6 +; CHECK-NEXT: store i16 [[TMP33]], ptr addrspace(7) [[TMP32]], align 2 +; CHECK-NEXT: [[TMP34:%.*]] = add i32 [[TMP4]], 1120 +; CHECK-NEXT: [[TMP35:%.*]] = add i32 [[TMP34]], 0 +; CHECK-NEXT: [[TMP36:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP35]] +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <16 x i16> [[TMP5]], i64 7 +; CHECK-NEXT: store i16 [[TMP37]], ptr addrspace(7) [[TMP36]], align 2 +; CHECK-NEXT: [[TMP38:%.*]] = add i32 [[TMP4]], 1280 +; CHECK-NEXT: [[TMP39:%.*]] = add i32 [[TMP38]], 0 +; CHECK-NEXT: [[TMP40:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP39]] +; CHECK-NEXT: [[TMP41:%.*]] = extractelement <16 x i16> [[TMP5]], i64 8 +; CHECK-NEXT: store i16 [[TMP41]], ptr addrspace(7) [[TMP40]], align 2 +; CHECK-NEXT: [[TMP42:%.*]] = add i32 [[TMP4]], 1440 +; CHECK-NEXT: [[TMP43:%.*]] = add i32 [[TMP42]], 0 +; CHECK-NEXT: [[TMP44:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP43]] +; CHECK-NEXT: [[TMP45:%.*]] = extractelement <16 x i16> [[TMP5]], i64 9 +; CHECK-NEXT: store i16 [[TMP45]], ptr addrspace(7) [[TMP44]], align 2 +; CHECK-NEXT: [[TMP46:%.*]] = add i32 [[TMP4]], 1600 +; CHECK-NEXT: [[TMP47:%.*]] = add i32 [[TMP46]], 0 +; CHECK-NEXT: [[TMP48:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP47]] +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <16 x i16> [[TMP5]], i64 10 +; CHECK-NEXT: store i16 [[TMP49]], ptr addrspace(7) [[TMP48]], align 2 +; CHECK-NEXT: [[TMP50:%.*]] = add i32 [[TMP4]], 1760 +; CHECK-NEXT: [[TMP51:%.*]] = add i32 [[TMP50]], 0 +; CHECK-NEXT: [[TMP52:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP51]] +; CHECK-NEXT: [[TMP53:%.*]] = extractelement <16 x i16> [[TMP5]], i64 11 +; CHECK-NEXT: store i16 [[TMP53]], ptr addrspace(7) [[TMP52]], align 2 +; CHECK-NEXT: [[TMP54:%.*]] = add i32 [[TMP4]], 1920 +; CHECK-NEXT: [[TMP55:%.*]] = add i32 [[TMP54]], 0 +; CHECK-NEXT: [[TMP56:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP55]] +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <16 x i16> [[TMP5]], i64 12 +; CHECK-NEXT: store i16 [[TMP57]], ptr addrspace(7) [[TMP56]], align 2 +; CHECK-NEXT: [[TMP58:%.*]] = add i32 [[TMP4]], 2080 +; CHECK-NEXT: [[TMP59:%.*]] = add i32 [[TMP58]], 0 +; CHECK-NEXT: [[TMP60:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP59]] +; CHECK-NEXT: [[TMP61:%.*]] = extractelement <16 x i16> [[TMP5]], i64 13 +; CHECK-NEXT: store i16 [[TMP61]], ptr addrspace(7) [[TMP60]], align 2 +; CHECK-NEXT: [[TMP62:%.*]] = add i32 [[TMP4]], 2240 +; CHECK-NEXT: [[TMP63:%.*]] = add i32 [[TMP62]], 0 +; CHECK-NEXT: [[TMP64:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP63]] +; CHECK-NEXT: [[TMP65:%.*]] = extractelement <16 x i16> [[TMP5]], i64 14 +; CHECK-NEXT: store i16 [[TMP65]], ptr addrspace(7) [[TMP64]], align 2 +; CHECK-NEXT: [[TMP66:%.*]] = add i32 [[TMP4]], 2400 +; CHECK-NEXT: [[TMP67:%.*]] = add i32 [[TMP66]], 0 +; CHECK-NEXT: [[TMP68:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP67]] +; CHECK-NEXT: [[TMP69:%.*]] = extractelement <16 x i16> [[TMP5]], i64 15 +; CHECK-NEXT: store i16 [[TMP69]], ptr addrspace(7) [[TMP68]], align 2 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 4, i32 0, i32 0, <8 x i32> %a) + ret void +} + +define void @test_i16_cd_layout(ptr addrspace(7) %ptr, <8 x i32> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i16_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <8 x i32> [[A:%.*]] to <16 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <16 x i16> [[TMP7]], <16 x i16> poison, <4 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <4 x i16> [[TMP8]], i64 0 +; CHECK-NEXT: store i16 [[TMP12]], ptr addrspace(7) [[TMP11]], align 2 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 [[TMP13]], 0 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP14]] +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x i16> [[TMP8]], i64 1 +; CHECK-NEXT: store i16 [[TMP16]], ptr addrspace(7) [[TMP15]], align 2 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP18:%.*]] = add i32 [[TMP17]], 0 +; CHECK-NEXT: [[TMP19:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP18]] +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i16> [[TMP8]], i64 2 +; CHECK-NEXT: store i16 [[TMP20]], ptr addrspace(7) [[TMP19]], align 2 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP22:%.*]] = add i32 [[TMP21]], 0 +; CHECK-NEXT: [[TMP23:%.*]] = getelementptr i16, ptr addrspace(7) [[PTR]], i32 [[TMP22]] +; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x i16> [[TMP8]], i64 3 +; CHECK-NEXT: store i16 [[TMP24]], ptr addrspace(7) [[TMP23]], align 2 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8i32(ptr addrspace(7) %ptr, i32 320, i1 false, i32 4, i32 1, i32 0, <8 x i32> %a) + ret void +} + +define void @test_f32_cd_layout(ptr addrspace(7) %ptr, <8 x float> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_f32_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> poison, <4 x i32> +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP8]], 0 +; CHECK-NEXT: [[TMP10:%.*]] = getelementptr float, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <4 x float> [[TMP7]], i64 0 +; CHECK-NEXT: store float [[TMP11]], ptr addrspace(7) [[TMP10]], align 4 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP12]], 0 +; CHECK-NEXT: [[TMP14:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP13]] +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <4 x float> [[TMP7]], i64 1 +; CHECK-NEXT: store float [[TMP15]], ptr addrspace(7) [[TMP14]], align 4 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP16]], 0 +; CHECK-NEXT: [[TMP18:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x float> [[TMP7]], i64 2 +; CHECK-NEXT: store float [[TMP19]], ptr addrspace(7) [[TMP18]], align 4 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP20]], 0 +; CHECK-NEXT: [[TMP22:%.*]] = getelementptr float, ptr addrspace(7) [[PTR]], i32 [[TMP21]] +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <4 x float> [[TMP7]], i64 3 +; CHECK-NEXT: store float [[TMP23]], ptr addrspace(7) [[TMP22]], align 4 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) %ptr, i32 640, i1 false, i32 2, i32 1, i32 0, <8 x float> %a) + ret void +} + +define void @test_i32_cd_layout(ptr addrspace(7) %ptr, <8 x i32> %a) !spirv.ExecutionModel !8 !lgc.shaderstage !9 { +; CHECK-LABEL: @test_i32_cd_layout( +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = srem i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = udiv i32 [[TMP2]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i32 [[TMP4]], 160 +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <4 x i32> +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP6]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP8]], 0 +; CHECK-NEXT: [[TMP10:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR:%.*]], i32 [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <4 x i32> [[TMP7]], i64 0 +; CHECK-NEXT: store i32 [[TMP11]], ptr addrspace(7) [[TMP10]], align 4 +; CHECK-NEXT: [[TMP12:%.*]] = add i32 [[TMP6]], 640 +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP12]], 0 +; CHECK-NEXT: [[TMP14:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP13]] +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <4 x i32> [[TMP7]], i64 1 +; CHECK-NEXT: store i32 [[TMP15]], ptr addrspace(7) [[TMP14]], align 4 +; CHECK-NEXT: [[TMP16:%.*]] = add i32 [[TMP6]], 1280 +; CHECK-NEXT: [[TMP17:%.*]] = add i32 [[TMP16]], 0 +; CHECK-NEXT: [[TMP18:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x i32> [[TMP7]], i64 2 +; CHECK-NEXT: store i32 [[TMP19]], ptr addrspace(7) [[TMP18]], align 4 +; CHECK-NEXT: [[TMP20:%.*]] = add i32 [[TMP6]], 1920 +; CHECK-NEXT: [[TMP21:%.*]] = add i32 [[TMP20]], 0 +; CHECK-NEXT: [[TMP22:%.*]] = getelementptr i32, ptr addrspace(7) [[PTR]], i32 [[TMP21]] +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <4 x i32> [[TMP7]], i64 3 +; CHECK-NEXT: store i32 [[TMP23]], ptr addrspace(7) [[TMP22]], align 4 +; CHECK-NEXT: ret void +; + call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8i32(ptr addrspace(7) %ptr, i32 640, i1 false, i32 5, i32 1, i32 0, <8 x i32> %a) + ret void +} + +declare void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32([4 x float] addrspace(7)*, i32, i1, i32, i32, i32, <8 x float>) +declare void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8i32([4 x float] addrspace(7)*, i32, i1, i32, i32, i32, <8 x i32>) + +!llpc.compute.mode = !{!0} +!lgc.client = !{!1} +!lgc.options = !{!2} +!lgc.options.CS = !{!3} + +!0 = !{i32 128, i32 2, i32 1} +!1 = !{!"Vulkan"} +!2 = !{i32 -2108299168, i32 -1199997545, i32 1667044824, i32 -422575072, i32 1, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 1, i32 0, i32 0, i32 -1} +!3 = !{i32 219437737, i32 -1317595285, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 64, i32 64, i32 0, i32 0, i32 3, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 20, i32 1800} +!8 = !{i32 5} +!9 = !{i32 7} diff --git a/lgc/tool/lgc/CMakeLists.txt b/lgc/tool/lgc/CMakeLists.txt index 41a610baa7..ee9d05baad 100644 --- a/lgc/tool/lgc/CMakeLists.txt +++ b/lgc/tool/lgc/CMakeLists.txt @@ -45,7 +45,7 @@ add_llvm_tool(lgc # lgc is linked in separately to account for both static and dynamic library # builds. -llvm_map_components_to_libnames(extra_llvm_libs lgc Continuations) +llvm_map_components_to_libnames(extra_llvm_libs lgc CompilerUtils Continuations) target_link_libraries(lgc PRIVATE ${extra_llvm_libs}) target_compile_definitions(lgc PRIVATE ${TARGET_ARCHITECTURE_ENDIANESS}ENDIAN_CPU) diff --git a/lgc/util/CpsStackLowering.cpp b/lgc/util/CpsStackLowering.cpp deleted file mode 100644 index 7daac65800..0000000000 --- a/lgc/util/CpsStackLowering.cpp +++ /dev/null @@ -1,240 +0,0 @@ -/* - *********************************************************************************************************************** - * - * Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - **********************************************************************************************************************/ - -#include "lgc/util/CpsStackLowering.h" -#include "lgc/util/BuilderBase.h" -#include "llvm-dialects/Dialect/Visitor.h" - -using namespace llvm; -using namespace lgc; - -LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(CpsStackLowering, m_typeLowering) -// ===================================================================================================================== -// Type lowering rule that lowers cps stack pointer type to corresponding backend pointer type. -// -// @param typeLowering : the calling TypeLowering object -// @param type : the type to be converted -static SmallVector convertCpsStackPointer(TypeLowering &typeLowering, Type *type) { - SmallVector types; - - if (auto *pointerType = dyn_cast(type)) { - if (pointerType->getAddressSpace() == cps::stackAddrSpace) - types.push_back(PointerType::get(type->getContext(), getLoweredCpsStackAddrSpace())); - } - - return types; -} - -// ===================================================================================================================== -// Lower continuation stack operations in the function -// -// @param function : the function to be processed -// @param cpsStorage : the alloca used for the holding the latest continuation stack pointer -void CpsStackLowering::lowerCpsStackOps(Function &function, Value *cpsStorage) { - m_module = function.getParent(); - m_cpsStackAlloca = cpsStorage; - m_typeLowering.addRule(&convertCpsStackPointer); - auto *newFunc = &function; - if (cps::isCpsFunction(function)) - newFunc = m_typeLowering.lowerFunctionArguments(function); - - static const auto visitor = llvm_dialects::VisitorBuilder() - .nest(&TypeLowering::registerVisitors) - .add(&CpsStackLowering::visitCpsAlloc) - .add(&CpsStackLowering::visitCpsFree) - .add(&CpsStackLowering::visitCpsPeek) - .add(&CpsStackLowering::visitSetVsp) - .add(&CpsStackLowering::visitGetVsp) - .add(&CpsStackLowering::visitGetElementPtr) - .add(&CpsStackLowering::visitPtrToIntInst) - .add(&CpsStackLowering::visitIntToPtrInst) - .add(&CpsStackLowering::visitLoad) - .add(&CpsStackLowering::visitStore) - .build(); - visitor.visit(*this, *newFunc); - m_typeLowering.finishPhis(); - m_typeLowering.finishCleanup(); -} - -// ===================================================================================================================== -// Lower getelementptr instruction -// -// @param function : the instruction -void CpsStackLowering::visitGetElementPtr(GetElementPtrInst &getElemPtrInst) { - if (getElemPtrInst.getAddressSpace() != cps::stackAddrSpace) - return; - - auto values = m_typeLowering.getValue(getElemPtrInst.getPointerOperand()); - IRBuilder<> builder(&getElemPtrInst); - - SmallVector indices(getElemPtrInst.idx_begin(), getElemPtrInst.idx_end()); - - Value *newGetElemPtr = nullptr; - auto getElemPtrPtr = values[0]; - auto getElemPtrEltTy = getElemPtrInst.getSourceElementType(); - - if (getElemPtrInst.isInBounds()) - newGetElemPtr = builder.CreateInBoundsGEP(getElemPtrEltTy, getElemPtrPtr, indices); - else - newGetElemPtr = builder.CreateGEP(getElemPtrEltTy, getElemPtrPtr, indices); - - cast(newGetElemPtr)->copyMetadata(getElemPtrInst); - - m_typeLowering.replaceInstruction(&getElemPtrInst, {newGetElemPtr}); -} - -// ===================================================================================================================== -// Lower load instruction -// -// @param function : the instruction -void CpsStackLowering::visitLoad(LoadInst &load) { - if (load.getPointerAddressSpace() != cps::stackAddrSpace) - return; - - auto values = m_typeLowering.getValue(load.getPointerOperand()); - load.replaceUsesOfWith(load.getPointerOperand(), values[0]); -} - -// ===================================================================================================================== -// Lower store instruction -// -// @param function : the instruction -void CpsStackLowering::visitStore(llvm::StoreInst &store) { - if (store.getPointerAddressSpace() != cps::stackAddrSpace) - return; - - auto values = m_typeLowering.getValue(store.getPointerOperand()); - store.replaceUsesOfWith(store.getPointerOperand(), values[0]); -} - -// ===================================================================================================================== -// Lower ptrtoint instruction -// -// @param function : the instruction -void CpsStackLowering::visitPtrToIntInst(llvm::PtrToIntInst &ptr2Int) { - if (ptr2Int.getPointerAddressSpace() != cps::stackAddrSpace) - return; - - auto values = m_typeLowering.getValue(ptr2Int.getOperand(0)); - ptr2Int.replaceUsesOfWith(ptr2Int.getOperand(0), values[0]); -} - -// ===================================================================================================================== -// Lower inttoptr instruction -// -// @param function : the instruction -void CpsStackLowering::visitIntToPtrInst(llvm::IntToPtrInst &int2Ptr) { - if (int2Ptr.getAddressSpace() != cps::stackAddrSpace) - return; - - IRBuilder<> builder(&int2Ptr); - auto *newPtr = builder.CreateIntToPtr(int2Ptr.getOperand(0), - PointerType::get(builder.getContext(), getLoweredCpsStackAddrSpace())); - m_typeLowering.replaceInstruction(&int2Ptr, newPtr); -} - -// ===================================================================================================================== -// Lower lgc.cps.alloc instruction -// -// @param function : the instruction -void CpsStackLowering::visitCpsAlloc(cps::AllocOp &alloc) { - IRBuilder<> builder(&alloc); - - Value *size = alloc.getSize(); - const DataLayout &layout = m_module->getDataLayout(); - Value *vsp = builder.CreateAlignedLoad(builder.getPtrTy(getLoweredCpsStackAddrSpace()), m_cpsStackAlloca, - Align(getLoweredCpsStackPointerSize(layout))); - unsigned alignedSize = alignTo(cast(size)->getZExtValue(), continuationStackAlignment); - m_stackSizeInBytes += alignedSize; - - // update stack pointer - Value *ptr = builder.CreateConstGEP1_32(builder.getInt8Ty(), vsp, alignedSize); - builder.CreateAlignedStore(ptr, m_cpsStackAlloca, Align(getLoweredCpsStackPointerSize(layout))); - - m_typeLowering.replaceInstruction(&alloc, {vsp}); -} - -// ===================================================================================================================== -// Lower lgc.cps.free instruction -// -// @param function : the instruction -void CpsStackLowering::visitCpsFree(cps::FreeOp &freeOp) { - IRBuilder<> builder(&freeOp); - const DataLayout &layout = m_module->getDataLayout(); - - Value *vsp = builder.CreateAlignedLoad(builder.getPtrTy(getLoweredCpsStackAddrSpace()), m_cpsStackAlloca, - Align(getLoweredCpsStackPointerSize(layout))); - Value *size = freeOp.getSize(); - unsigned alignedSize = alignTo(cast(size)->getZExtValue(), continuationStackAlignment); - Value *ptr = builder.CreateConstGEP1_32(builder.getInt8Ty(), vsp, -alignedSize); - // Assuming continuation stack grows upward. - builder.CreateAlignedStore(ptr, m_cpsStackAlloca, Align(getLoweredCpsStackPointerSize(layout))); - m_typeLowering.replaceInstruction(&freeOp, {}); -} - -// ===================================================================================================================== -// Lower lgc.cps.peek instruction -// -// @param function : the instruction -void CpsStackLowering::visitCpsPeek(cps::PeekOp &peekOp) { - IRBuilder<> builder(&peekOp); - const DataLayout &layout = m_module->getDataLayout(); - - auto *ptr = builder.CreateAlignedLoad(builder.getPtrTy(getLoweredCpsStackAddrSpace()), m_cpsStackAlloca, - Align(getLoweredCpsStackPointerSize(layout))); - auto *size = peekOp.getSize(); - unsigned immSize = cast(size)->getZExtValue(); - immSize = alignTo(immSize, continuationStackAlignment); - // Assuming continuation stack grows upward. - auto *result = builder.CreateGEP(builder.getInt8Ty(), ptr, {builder.getInt32(-immSize)}); - m_typeLowering.replaceInstruction(&peekOp, {result}); -} - -// ===================================================================================================================== -// Lower lgc.cps.set.vsp instruction -// -// @param function : the instruction -void CpsStackLowering::visitSetVsp(cps::SetVspOp &setVsp) { - IRBuilder<> builder(&setVsp); - const DataLayout &layout = m_module->getDataLayout(); - - auto *ptr = setVsp.getPtr(); - auto converted = m_typeLowering.getValue(ptr); - builder.CreateAlignedStore(converted[0], m_cpsStackAlloca, Align(getLoweredCpsStackPointerSize(layout))); - m_typeLowering.replaceInstruction(&setVsp, {}); -} - -// ===================================================================================================================== -// Lower lgc.cps.get.vsp instruction -// -// @param function : the instruction -void CpsStackLowering::visitGetVsp(cps::GetVspOp &getVsp) { - IRBuilder<> builder(&getVsp); - const DataLayout &layout = m_module->getDataLayout(); - - auto *ptr = builder.CreateAlignedLoad(builder.getPtrTy(getLoweredCpsStackAddrSpace()), m_cpsStackAlloca, - Align(getLoweredCpsStackPointerSize(layout))); - m_typeLowering.replaceInstruction(&getVsp, {ptr}); -} diff --git a/lgc/util/Internal.cpp b/lgc/util/Internal.cpp index edc85bf436..bb17000f9d 100644 --- a/lgc/util/Internal.cpp +++ b/lgc/util/Internal.cpp @@ -234,16 +234,4 @@ Type *getVgprTy(Type *ty) { return ty; } -Function *mutateFunctionArguments(Function &fn, Type *retTy, const ArrayRef argTys, AttributeList attributes) { - FunctionType *newFnTy = FunctionType::get(retTy, argTys, false); - auto *newFn = Function::Create(newFnTy, fn.getLinkage()); - newFn->copyAttributesFrom(&fn); - newFn->copyMetadata(&fn, 0); - newFn->takeName(&fn); - newFn->setAttributes(attributes); - newFn->splice(newFn->begin(), &fn); - fn.getParent()->getFunctionList().insertAfter(fn.getIterator(), newFn); - return newFn; -} - } // namespace lgc diff --git a/llpc/CMakeLists.txt b/llpc/CMakeLists.txt index aba14388b6..72331a2fb8 100644 --- a/llpc/CMakeLists.txt +++ b/llpc/CMakeLists.txt @@ -55,11 +55,8 @@ if(ICD_BUILD_LLPC) set(LLVM_INCLUDE_DOCS OFF CACHE BOOL Force) set(LLVM_INCLUDE_EXAMPLES OFF CACHE BOOL Force) set(LLVM_INCLUDE_GO_TESTS OFF CACHE BOOL Force) - if(LLPC_BUILD_TESTS) - set(LLVM_INCLUDE_TESTS ON CACHE BOOL Force) - else() - set(LLVM_INCLUDE_TESTS OFF CACHE BOOL Force) - endif() + set(CONTINUATIONS_BUILD_TESTS ${LLPC_BUILD_TESTS}) + set(LLVM_INCLUDE_TESTS ${LLPC_BUILD_TESTS} CACHE BOOL Force) set(LLVM_INCLUDE_TOOLS ON CACHE BOOL Force) set(LLVM_INCLUDE_UTILS ON CACHE BOOL Force) set(LLVM_ENABLE_TERMINFO OFF CACHE BOOL Force) @@ -228,6 +225,7 @@ if(ICD_BUILD_LLPC) lower/LowerGpuRt.cpp lower/llpcSpirvLowerInternalLibraryIntrinsicUtil.cpp lower/LowerGLCompatibility.cpp + lower/llpcSpirvLowerCooperativeMatrix.cpp ) # llpc/translator diff --git a/llpc/context/llpcCompiler.cpp b/llpc/context/llpcCompiler.cpp index aca56cdaf3..5ee3fb977b 100644 --- a/llpc/context/llpcCompiler.cpp +++ b/llpc/context/llpcCompiler.cpp @@ -514,6 +514,49 @@ void Compiler::Destroy() { delete this; } +// ===================================================================================================================== +// Merge location and binding value, and replace the binding decoration in spirv binary. +// +// @param codeBuffer : Spirv binary +// @param imageSymbolInfo : Image symbol infos +static void mergeSpirvLocationAndBinding(llvm::MutableArrayRef codeBuffer, + std::vector &imageSymbolInfo) { + constexpr unsigned wordSize = sizeof(unsigned); + + unsigned *code = codeBuffer.data(); + unsigned *end = code + codeBuffer.size(); + unsigned *codePos = code + sizeof(SpirvHeader) / wordSize; + + while (codePos < end) { + unsigned opCode = (codePos[0] & OpCodeMask); + unsigned wordCount = (codePos[0] >> WordCountShift); + + switch (opCode) { + case OpDecorate: { + auto decoration = static_cast(codePos[2]); + + if (decoration == DecorationBinding) { + uint32_t varId = codePos[1]; + uint32_t binding = codePos[3]; + uint32_t location = 0; + for (auto it = imageSymbolInfo.begin(); it != imageSymbolInfo.end(); ++it) { + if (it->spvId == varId && it->binding == binding) { + location = it->location; + it->mergedLocationBinding = true; + } + } + uint32_t locationBinding = location << 16 | binding; + codePos[3] = locationBinding; + } + } break; + default: + break; + } + + codePos += wordCount; + } +} + // ===================================================================================================================== // Builds shader module from the specified info. // @@ -563,6 +606,9 @@ Result Compiler::BuildShaderModule(const ShaderModuleBuildInfo *shaderInfo, Shad allocSize += imageSymbolInfo.size() * sizeof(ResourceNodeData); allocSize += atomicCounterSymbolInfo.size() * sizeof(ResourceNodeData); allocSize += defaultUniformSymbolInfo.size() * sizeof(ResourceNodeData); + + if (imageSymbolInfo.size() && shaderInfo->options.mergeLocationAndBinding) + mergeSpirvLocationAndBinding(codeBuffer, imageSymbolInfo); } uint8_t *allocBuf = @@ -641,12 +687,14 @@ static bool getSymbolInfoFromSpvVariable(const SPIRVVariable *spvVar, ResourceNo uint32_t arraySize = 1; SPIRVWord location = 0; SPIRVWord binding = 0; + SPIRVWord varId = 0; BasicType basicType = BasicType::Unknown; SPIRVWord builtIn = false; bool isBuiltIn = spvVar->hasDecorate(DecorationBuiltIn, 0, &builtIn); spvVar->hasDecorate(DecorationLocation, 0, &location); spvVar->hasDecorate(DecorationBinding, 0, &binding); + varId = spvVar->getId(); SPIRVType *varElemTy = spvVar->getType()->getPointerElementType(); while (varElemTy->isTypeArray()) { @@ -707,6 +755,7 @@ static bool getSymbolInfoFromSpvVariable(const SPIRVVariable *spvVar, ResourceNo symbolInfo->location = location; symbolInfo->binding = binding; symbolInfo->basicType = basicType; + symbolInfo->spvId = varId; return isBuiltIn; } diff --git a/llpc/context/llpcDialect.h b/llpc/context/llpcDialect.h index e6b6f7fa4f..a03b39e54a 100644 --- a/llpc/context/llpcDialect.h +++ b/llpc/context/llpcDialect.h @@ -30,4 +30,8 @@ */ #pragma once -namespace LlpcName {} // namespace LlpcName +namespace LlpcName { + +const static char SpirvCooperativeMatrixProxy[] = "spirv.cooperative.matrix.proxy"; + +} // namespace LlpcName diff --git a/llpc/context/llpcGraphicsContext.cpp b/llpc/context/llpcGraphicsContext.cpp index 62f88e550c..866d2b0e82 100644 --- a/llpc/context/llpcGraphicsContext.cpp +++ b/llpc/context/llpcGraphicsContext.cpp @@ -367,6 +367,7 @@ void GraphicsContext::setVertexInputDescriptions(Pipeline *pipeline, Util::Metro break; case VK_VERTEX_INPUT_RATE_INSTANCE: bindings[idx].inputRate = VertexInputRateInstance; + bindings[idx].divisor = 1; // Set default divisor break; default: llvm_unreachable("Should never be called!"); @@ -380,7 +381,7 @@ void GraphicsContext::setVertexInputDescriptions(Pipeline *pipeline, Util::Metro for (unsigned i = 0; i < vertexDivisor->vertexBindingDivisorCount; ++i) { auto divisor = &vertexDivisor->pVertexBindingDivisors[i]; if (divisor->binding <= bindings.size()) - bindings[divisor->binding].inputRate = divisor->divisor; + bindings[divisor->binding].divisor = divisor->divisor; } } @@ -409,6 +410,7 @@ void GraphicsContext::setVertexInputDescriptions(Pipeline *pipeline, Util::Metro dfmt, nfmt, binding->inputRate, + binding->divisor, }); } } diff --git a/llpc/context/llpcPipelineContext.cpp b/llpc/context/llpcPipelineContext.cpp index bfadae17c4..0f0e508baf 100644 --- a/llpc/context/llpcPipelineContext.cpp +++ b/llpc/context/llpcPipelineContext.cpp @@ -340,6 +340,7 @@ Options PipelineContext::computePipelineOptions() const { options.disableSampleMask = getPipelineOptions()->disableSampleMask; options.disableTruncCoordForGather = getPipelineOptions()->disableTruncCoordForGather; options.enablePrimGeneratedQuery = getPipelineOptions()->enablePrimGeneratedQuery; + options.enableFragColor = getPipelineOptions()->enableFragColor; return options; } @@ -864,13 +865,13 @@ std::pair PipelineContext::mapVkFormat(VkFormat for BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16_UINT, BufDataFormat16_16, BufNumFormatUint), BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16_SINT, BufDataFormat16_16, BufNumFormatSint), BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16_SFLOAT, BufDataFormat16_16, BufNumFormatFloat), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_UNORM), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SNORM), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_USCALED), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SSCALED), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_UINT), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SINT), - INVALID_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SFLOAT), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_UNORM, BufDataFormat16_16_16, BufNumFormatUnorm), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SNORM, BufDataFormat16_16_16, BufNumFormatSnorm), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_USCALED, BufDataFormat16_16_16, BufNumFormatUscaled), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SSCALED, BufDataFormat16_16_16, BufNumFormatSscaled), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_UINT, BufDataFormat16_16_16, BufNumFormatUint), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SINT, BufDataFormat16_16_16, BufNumFormatSint), + BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16_SFLOAT, BufDataFormat16_16_16, BufNumFormatFloat), BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16A16_UNORM, BufDataFormat16_16_16_16, BufNumFormatUnorm), BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16A16_SNORM, BufDataFormat16_16_16_16, BufNumFormatSnorm), BOTH_FORMAT_ENTRY(VK_FORMAT_R16G16B16A16_USCALED, BufDataFormat16_16_16_16, BufNumFormatUscaled), @@ -1013,6 +1014,8 @@ std::pair PipelineContext::mapVkFormat(VkFormat for #endif COLOR_FORMAT_ENTRY_EXT(VK_FORMAT_A4R4G4B4_UNORM_PACK16_EXT, BufDataFormat4_4_4_4, BufNumFormatUnorm), COLOR_FORMAT_ENTRY_EXT(VK_FORMAT_A4B4G4R4_UNORM_PACK16_EXT, BufDataFormat4_4_4_4, BufNumFormatUnorm), + COLOR_FORMAT_ENTRY_EXT(VK_FORMAT_A1B5G5R5_UNORM_PACK16, BufDataFormat1_5_6_5, BufNumFormatUnorm), + COLOR_FORMAT_ENTRY_EXT(VK_FORMAT_A8_UNORM_KHR, BufDataFormat8_A, BufNumFormatUnorm), /// Currently OGL-only : Internal spv ext vertex attribute format - begin EXT_VERTEX_FORMAT_ENTRY(VK_FORMAT_EXT_R32_UNORM, BufDataFormat32, BufNumFormatUnorm), EXT_VERTEX_FORMAT_ENTRY(VK_FORMAT_EXT_R32G32_UNORM, BufDataFormat32_32, BufNumFormatUnorm), diff --git a/llpc/context/llpcRayTracingContext.cpp b/llpc/context/llpcRayTracingContext.cpp index b84747d643..42b28539bd 100644 --- a/llpc/context/llpcRayTracingContext.cpp +++ b/llpc/context/llpcRayTracingContext.cpp @@ -281,6 +281,8 @@ lgc::Options RayTracingContext::computePipelineOptions() const { options.rtIndirectMode = lgc::RayTracingIndirectMode::ContinuationsContinufy; else if (m_pipelineInfo->mode == Vkgc::LlpcRaytracingMode::Continuations) options.rtIndirectMode = lgc::RayTracingIndirectMode::Continuations; + + options.cpsFlags = m_pipelineInfo->cpsFlags; #endif return options; diff --git a/llpc/include/llpc.h b/llpc/include/llpc.h index 3a7f5bfa66..358f1e5e09 100644 --- a/llpc/include/llpc.h +++ b/llpc/include/llpc.h @@ -103,6 +103,7 @@ static const char VkIcdName[] = "amdvlk"; /// Represents per shader module options. struct ShaderModuleOptions { PipelineOptions pipelineOptions; ///< Pipeline options related with this shader module + bool mergeLocationAndBinding; }; /// Represents info to build a shader module. diff --git a/llpc/lower/LowerGLCompatibility.cpp b/llpc/lower/LowerGLCompatibility.cpp index cf073a73c9..badffab8d9 100644 --- a/llpc/lower/LowerGLCompatibility.cpp +++ b/llpc/lower/LowerGLCompatibility.cpp @@ -282,10 +282,16 @@ void LowerGLCompatibility::createClipPlane() { auto locationFound = getUniformConstantEntryByLocation(m_context, m_shaderStage, Vkgc::GlCompatibilityUniformLocation::ClipPlane); auto clipPlaneBaseOffset = locationFound != nullptr ? locationFound->offset : 0; + assert(m_shaderStage != ShaderStageTask && m_shaderStage != ShaderStageMesh); + unsigned constBufferBinding = + Vkgc::ConstantBuffer0Binding + static_cast(m_context->getPipelineContext()) + ->getPipelineShaderInfo(m_shaderStage) + ->options.constantBufferBindingOffset; + std::vector mDs; auto int32Ty = Type::getInt32Ty(*m_context); mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, Vkgc::InternalDescriptorSetId))); - mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, Vkgc::ConstantBuffer0Binding))); + mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, constBufferBinding))); mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, clipPlaneBaseOffset))); mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, Vkgc::GlCompatibilityUniformLocation::ClipPlane))); auto mdNode = MDNode::get(*m_context, mDs); diff --git a/llpc/lower/PassRegistry.inc b/llpc/lower/PassRegistry.inc index c85ef22d3c..c0afde0d2a 100644 --- a/llpc/lower/PassRegistry.inc +++ b/llpc/lower/PassRegistry.inc @@ -37,6 +37,7 @@ LLPC_MODULE_PASS("llpc-spirv-lower-gl-compatibility", LowerGLCompatibility) LLPC_MODULE_PASS("llpc-spirv-lower-access-chain", SpirvLowerAccessChain) LLPC_MODULE_PASS("llpc-spirv-lower-cfg-merges", SpirvLowerCfgMerges) LLPC_MODULE_PASS("llpc-spirv-lower-const-immediate-store", SpirvLowerConstImmediateStore) +LLPC_MODULE_PASS("llpc-spirv-lower-cooperative-matrix", SpirvLowerCooperativeMatrix) LLPC_MODULE_PASS("llpc-spirv-lower-inst-meta-remove", SpirvLowerInstMetaRemove) LLPC_MODULE_PASS("llpc-spirv-lower-terminator", SpirvLowerTerminator) LLPC_MODULE_PASS("llpc-spirv-lower-translator", SpirvLowerTranslator) diff --git a/llpc/lower/llpcSpirvLower.cpp b/llpc/lower/llpcSpirvLower.cpp index 52ba78f9c0..1d225ff8b7 100644 --- a/llpc/lower/llpcSpirvLower.cpp +++ b/llpc/lower/llpcSpirvLower.cpp @@ -35,6 +35,7 @@ #include "llpcSpirvLowerAccessChain.h" #include "llpcSpirvLowerCfgMerges.h" #include "llpcSpirvLowerConstImmediateStore.h" +#include "llpcSpirvLowerCooperativeMatrix.h" #include "llpcSpirvLowerGlobal.h" #include "llpcSpirvLowerInstMetaRemove.h" #include "llpcSpirvLowerMath.h" @@ -201,6 +202,9 @@ void SpirvLower::addPasses(Context *context, ShaderStage stage, lgc::PassManager // Lower SPIR-V terminators passMgr.addPass(SpirvLowerTerminator()); + // Lower spirv.cooperative.matrix.proxy to LGC operations. Should run before SROA. + passMgr.addPass(SpirvLowerCooperativeMatrix()); + // Lower Glsl compatibility variables and operations passMgr.addPass(LowerGLCompatibility()); diff --git a/llpc/lower/llpcSpirvLowerCooperativeMatrix.cpp b/llpc/lower/llpcSpirvLowerCooperativeMatrix.cpp new file mode 100644 index 0000000000..6280a79364 --- /dev/null +++ b/llpc/lower/llpcSpirvLowerCooperativeMatrix.cpp @@ -0,0 +1,169 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file llpcSpirvLowerCooperativeMatrix.cpp + * @brief LLPC source file: pass that lower SPIR-V specific cooperative matrix operations + * + * This currently only handles spirv.cooperative.matrix.proxy, which is used to proxy pointers to cooperative matrix + * values for component load/store. + *********************************************************************************************************************** + */ + +#include "llpcSpirvLowerCooperativeMatrix.h" +#include "llpcDialect.h" +#include "lgc/BuilderCommon.h" +#include "llvm/IR/Instructions.h" + +#define DEBUG_TYPE "llpc-spirv-lower-cooperative-matrix" + +using namespace llvm; +using namespace lgc; +using namespace Llpc; + +namespace { + +// ===================================================================================================================== +// Implementation class of the pass, hidden from external access. +class LowerCooperativeMatrix { +public: + LowerCooperativeMatrix(Module &module) : m_module(module), m_builder(module.getContext()) {} + + PreservedAnalyses run(); + +private: + void visitProxy(CallInst &call); + void visitPointerUsers(Value *ptr, BuilderCommon::CooperativeMatrixElementType elemTypeEnum, + BuilderCommon::CooperativeMatrixLayout layout, Type *elemType, Value *matrixPtr, Value *index); + + Module &m_module; + BuilderCommon m_builder; + SmallVector m_toDelete; +}; + +} // anonymous namespace + +// ===================================================================================================================== +// Run the lowering implementation. +PreservedAnalyses LowerCooperativeMatrix::run() { + bool changed = false; + + for (Function &function : m_module.functions()) { + if (function.isDeclaration() && function.getName().startswith(LlpcName::SpirvCooperativeMatrixProxy)) { + for (User *user : function.users()) { + if (auto *call = dyn_cast(user)) { + assert(call->getCalledOperand() == &function); + visitProxy(*call); + changed = true; + } + } + } + } + + for (Instruction *inst : reverse(m_toDelete)) + inst->eraseFromParent(); + + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +// ===================================================================================================================== +// Handle one call to spirv.cooperative.matrix.proxy. +// +// @param call : the call instruction +// @returns true if a change was made +void LowerCooperativeMatrix::visitProxy(CallInst &call) { + Value *ptr = call.getArgOperand(0); + auto elemTypeEnum = + (BuilderCommon::CooperativeMatrixElementType)(cast(call.getArgOperand(1))->getZExtValue()); + Type *elemType = m_builder.transCooperativeMatrixElementType(elemTypeEnum); + auto layout = (BuilderCommon::CooperativeMatrixLayout)(cast(call.getArgOperand(2))->getZExtValue()); + + m_toDelete.push_back(&call); + visitPointerUsers(&call, elemTypeEnum, layout, elemType, ptr, m_builder.getInt32(0)); +} + +// ===================================================================================================================== +// Handle all users of a pointer defined directly or indirectly via spirv.cooperative.matrix.proxy. +// +// @param ptr : the pointer whose users should be handled +// @param elemType : the matrix element type +// @param layout : the matrix layout +// @param matrixPtr : the pointer to the underlying proxied matrix +// @param index : the 32-bit index of the matrix that @p ptr points to +void LowerCooperativeMatrix::visitPointerUsers(Value *ptr, BuilderCommon::CooperativeMatrixElementType elemTypeEnum, + BuilderCommon::CooperativeMatrixLayout layout, Type *elemType, + Value *matrixPtr, Value *index) { + for (User *user : ptr->users()) { + Instruction *inst = cast(user); + m_builder.SetInsertPoint(inst); + + m_toDelete.push_back(inst); + + if (auto *load = dyn_cast(inst)) { + assert(load->getPointerOperand() == ptr); + assert(load->getType() == elemType); + + Type *matrixType = m_builder.getCooperativeMatrixTy(elemTypeEnum, layout); + Value *matrix = m_builder.CreateLoad(matrixType, matrixPtr); + Value *element = m_builder.CreateCooperativeMatrixExtract(matrix, index, elemTypeEnum, layout); + load->replaceAllUsesWith(element); + } else if (auto *store = dyn_cast(inst)) { + assert(store->getPointerOperand() == ptr); + assert(store->getValueOperand()->getType() == elemType); + + Type *matrixType = m_builder.getCooperativeMatrixTy(elemTypeEnum, layout); + Value *matrix = m_builder.CreateLoad(matrixType, matrixPtr); + matrix = m_builder.CreateCooperativeMatrixInsert(matrix, store->getValueOperand(), index, elemTypeEnum, layout); + m_builder.CreateStore(matrix, matrixPtr); + } else if (auto *gep = dyn_cast(inst)) { + assert(gep->getPointerOperand() == ptr); + assert(gep->getSourceElementType() == elemType); + assert(gep->getNumIndices() == 1); + + Value *gepIndex = gep->indices().begin()->get(); + gepIndex = m_builder.CreateZExtOrTrunc(gepIndex, m_builder.getInt32Ty()); + + bool baseIsZero = false; + if (auto *constIndex = dyn_cast(index)) + baseIsZero = constIndex->getZExtValue() == 0; + if (!baseIsZero) + gepIndex = m_builder.CreateAdd(index, gepIndex); + + visitPointerUsers(gep, elemTypeEnum, layout, elemType, matrixPtr, gepIndex); + } else { + llvm_unreachable("indirect users of spirv.cooperative.matrix.proxy pointer"); + } + } +} + +// ===================================================================================================================== +// Executes this SPIR-V lowering pass on the specified LLVM module. +// +// @param [in/out] module : LLVM module to be run on +// @param [in/out] analysisManager : Analysis manager to use for this transformation +PreservedAnalyses SpirvLowerCooperativeMatrix::run(Module &module, ModuleAnalysisManager &analysisManager) { + LowerCooperativeMatrix impl{module}; + return impl.run(); +} diff --git a/llpc/lower/llpcSpirvLowerCooperativeMatrix.h b/llpc/lower/llpcSpirvLowerCooperativeMatrix.h new file mode 100644 index 0000000000..36eb73a0b2 --- /dev/null +++ b/llpc/lower/llpcSpirvLowerCooperativeMatrix.h @@ -0,0 +1,46 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file llpcSpirvLowerCooperativeMatrix.h + * @brief LLPC header file: lower SPIR-V specific cooperative matrix operations to LGC + *********************************************************************************************************************** + */ +#pragma once + +#include "llvm/IR/PassManager.h" + +namespace Llpc { + +// ===================================================================================================================== +// Pass that lower SPIR-V-specific cooperative matrix operations +class SpirvLowerCooperativeMatrix : public llvm::PassInfoMixin { +public: + llvm::PreservedAnalyses run(llvm::Module &module, llvm::ModuleAnalysisManager &analysisManager); + + static llvm::StringRef name() { return "spirv-lower-cooperative-matrix"; } +}; + +} // namespace Llpc diff --git a/llpc/lower/llpcSpirvLowerGlobal.cpp b/llpc/lower/llpcSpirvLowerGlobal.cpp index 443501e0eb..6939fb06d0 100644 --- a/llpc/lower/llpcSpirvLowerGlobal.cpp +++ b/llpc/lower/llpcSpirvLowerGlobal.cpp @@ -32,6 +32,7 @@ #include "SPIRVInternal.h" #include "llpcContext.h" #include "llpcDebug.h" +#include "llpcGraphicsContext.h" #include "llpcRayTracingContext.h" #include "llpcSpirvLowerUtil.h" #include "lgc/LgcDialect.h" @@ -1224,8 +1225,14 @@ Value *SpirvLowerGlobal::addCallInstForInOutImport(Type *inOutTy, unsigned addrS Vkgc::GlCompatibilityUniformLocation::FrameBufferSize); if (winSize) { offset = winSize->offset; + assert(m_shaderStage != Vkgc::ShaderStageTask && m_shaderStage != Vkgc::ShaderStageMesh); + unsigned constBufferBinding = + Vkgc::ConstantBuffer0Binding + static_cast(m_context->getPipelineContext()) + ->getPipelineShaderInfo(m_shaderStage) + ->options.constantBufferBindingOffset; + Value *bufferDesc = - m_builder->CreateLoadBufferDesc(Vkgc::InternalDescriptorSetId, Vkgc::ConstantBuffer0Binding, + m_builder->CreateLoadBufferDesc(Vkgc::InternalDescriptorSetId, constBufferBinding, m_builder->getInt32(0), lgc::Builder::BufferFlagNonConst); // Layout is {width, height}, so the offset of height is added sizeof(float). Value *winHeightPtr = diff --git a/llpc/lower/llpcSpirvLowerRayQuery.cpp b/llpc/lower/llpcSpirvLowerRayQuery.cpp index d555fce808..b572df0e70 100644 --- a/llpc/lower/llpcSpirvLowerRayQuery.cpp +++ b/llpc/lower/llpcSpirvLowerRayQuery.cpp @@ -1321,7 +1321,6 @@ Value *SpirvLowerRayQuery::createGetInstanceNodeAddr(Value *instNodePtr, Value * Value *BvhAddr = PoisonValue::get(FixedVectorType::get(Type::getInt32Ty(*m_context), 2)); BvhAddr = m_builder->CreateInsertElement(BvhAddr, BvhAddrLo, uint64_t(0)); BvhAddr = m_builder->CreateInsertElement(BvhAddr, BvhAddrHi, 1); - // Mask out the node offset auto nodeOffsetMask = m_builder->getInt32(0xFFFFFFF8u); // Shift left by 3 to make it 64B aligned address diff --git a/llpc/test/shaderdb/gfx11/cooperativeMatrix/array-of-matrices.comp b/llpc/test/shaderdb/gfx11/cooperativeMatrix/array-of-matrices.comp new file mode 100644 index 0000000000..97c5983db4 --- /dev/null +++ b/llpc/test/shaderdb/gfx11/cooperativeMatrix/array-of-matrices.comp @@ -0,0 +1,62 @@ +// NOTE: Assertions have been autogenerated by tool/update_llpc_test_checks.py +// RUN: amdllpc -o - -gfxip 11.0 -emit-lgc %s | FileCheck -check-prefixes=CHECK %s +// REQUIRES: do-not-run-me + +#version 450 core +#pragma use_vulkan_memory_model +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_explicit_arithmetic_types : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout(set=0, binding=0, std430) buffer Buf { uvec4 x[]; } buf; + +layout(push_constant) uniform PushConstants { + int idx1; + int idx2; +}; + +#define ELT_SIZE 16 + +void main() { + coopmat A[2]; + coopMatLoad(A[0], buf.x, 0, ELT_SIZE / 8, 0); + coopMatLoad(A[1], buf.x, 32, ELT_SIZE / 8, 0); + + buf.x[0].x = uint(A[idx1][3]); + + coopMatStore(A[idx2], buf.x, 64, 4, 0); +} + +// CHECK-LABEL: @lgc.shader.CS.main( +// CHECK-LABEL: .entry: +// CHECK-NEXT: [[TMP0:%.*]] = call ptr addrspace(4) (...) @lgc.create.load.push.constants.ptr.p4() +// CHECK-NEXT: [[TMP1:%.*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 0, i32 0, i32 2) +// CHECK-NEXT: [[TMP2:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) [[TMP1]], i32 32, i1 true, i32 1, i32 0, i32 0) #[[ATTR1:[0-9]+]] +// CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds <{ [4294967295 x [4 x i32]] }>, ptr addrspace(7) [[TMP1]], i32 0, i32 0, i32 32 +// CHECK-NEXT: [[TMP4:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) [[TMP3]], i32 32, i1 true, i32 1, i32 0, i32 0) #[[ATTR1]] +// CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr addrspace(4) [[TMP0]], align 4 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i32 [[TMP5]], 2 +// CHECK-NEXT: br i1 [[TMP6]], label %[[LABEL7:.*]], label %[[LABEL12:.*]] +// CHECK: [[LABEL7]]: +// CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[TMP5]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = select i1 [[TMP8]], <8 x float> [[TMP4]], <8 x float> [[TMP2]] +// CHECK-NEXT: [[TMP10:%.*]] = call half @lgc.cooperative.matrix.extract.f16.v8f32.i32.i32.i32(<8 x float> [[TMP9]], i32 3, i32 1, i32 0) #[[ATTR3:[0-9]+]] +// CHECK-NEXT: [[TMP11:%.*]] = fptoui half [[TMP10]] to i32 +// CHECK-NEXT: br label %[[LABEL12]] +// CHECK: [[LABEL12]]: +// CHECK-NEXT: [[TMP13:%.*]] = phi i32 [ 0, [[DOTENTRY:%.*]] ], [ [[TMP11]], %[[LABEL7]] ] +// CHECK-NEXT: store i32 [[TMP13]], ptr addrspace(7) [[TMP1]], align 4 +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds <{ i32, i32 }>, ptr addrspace(4) [[TMP0]], i64 0, i32 1 +// CHECK-NEXT: [[TMP15:%.*]] = load i32, ptr addrspace(4) [[TMP14]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = icmp ult i32 [[TMP15]], 2 +// CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[TMP15]], 1 +// CHECK-NEXT: [[TMP18:%.*]] = select i1 [[TMP17]], <8 x float> [[TMP4]], <8 x float> [[TMP2]] +// CHECK-NEXT: [[TMP19:%.*]] = select i1 [[TMP16]], <8 x float> [[TMP18]], <8 x float> zeroinitializer +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds <{ [4294967295 x [4 x i32]] }>, ptr addrspace(7) [[TMP1]], i32 0, i32 0, i32 64 +// CHECK-NEXT: call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) [[TMP20]], i32 64, i1 true, i32 1, i32 0, i32 0, <8 x float> [[TMP19]]) #[[ATTR2:[0-9]+]] +// CHECK-NEXT: ret void +// diff --git a/llpc/test/shaderdb/gfx11/cooperativeMatrix/extract-insert.spvasm b/llpc/test/shaderdb/gfx11/cooperativeMatrix/extract-insert.spvasm new file mode 100644 index 0000000000..c5fbaac0cc --- /dev/null +++ b/llpc/test/shaderdb/gfx11/cooperativeMatrix/extract-insert.spvasm @@ -0,0 +1,147 @@ +; NOTE: Assertions have been autogenerated by tool/update_llpc_test_checks.py +; RUN: amdllpc -o - -gfxip 11.0 -emit-lgc %s | FileCheck -check-prefixes=CHECK %s + +; SPIR-V +; Version: 1.6 + +; This test is derived from a simple compute shader compiled with glslang. It was hand-edited to use +; OpCooperativeMatrixLengthKHR as an instruction instead of a specialization constant. + + OpCapability Shader + OpCapability Float16 + OpCapability VulkanMemoryModel + OpCapability CooperativeMatrixKHR + OpExtension "SPV_KHR_cooperative_matrix" + %2 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" %bufInA %bufInB %bufOut + OpExecutionModeId %main LocalSizeId %uint_32 %uint_1 %uint_1 + OpDecorate %_runtimearr_v4uint ArrayStride 16 + OpMemberDecorate %InputA 0 Offset 0 + OpDecorate %InputA Block + OpDecorate %bufInA DescriptorSet 0 + OpDecorate %bufInA Binding 0 + OpDecorate %_runtimearr_v4uint_0 ArrayStride 16 + OpMemberDecorate %InputB 0 Offset 0 + OpDecorate %InputB Block + OpDecorate %bufInB DescriptorSet 0 + OpDecorate %bufInB Binding 1 + OpDecorate %_runtimearr_v4uint_1 ArrayStride 16 + OpMemberDecorate %Output 0 Offset 0 + OpDecorate %Output Block + OpDecorate %bufOut DescriptorSet 0 + OpDecorate %bufOut Binding 2 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 + %uint_1 = OpConstant %uint 1 + %half = OpTypeFloat 16 + %uint_3 = OpConstant %uint 3 + %uint_16 = OpConstant %uint 16 + %13 = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_16 %uint_16 %uint_1 +%_ptr_Function_13 = OpTypePointer Function %13 + %v4uint = OpTypeVector %uint 4 +%_runtimearr_v4uint = OpTypeRuntimeArray %v4uint + %InputA = OpTypeStruct %_runtimearr_v4uint +%_ptr_StorageBuffer_InputA = OpTypePointer StorageBuffer %InputA + %bufInA = OpVariable %_ptr_StorageBuffer_InputA StorageBuffer + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint_0 = OpConstant %uint 0 +%_ptr_StorageBuffer_v4uint = OpTypePointer StorageBuffer %v4uint + %uint_2 = OpConstant %uint 2 + %bool = OpTypeBool + %false = OpConstantFalse %bool +%_runtimearr_v4uint_0 = OpTypeRuntimeArray %v4uint + %InputB = OpTypeStruct %_runtimearr_v4uint_0 +%_ptr_StorageBuffer_InputB = OpTypePointer StorageBuffer %InputB + %bufInB = OpVariable %_ptr_StorageBuffer_InputB StorageBuffer +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Function_half = OpTypePointer Function %half + %int_1 = OpConstant %int 1 +%_runtimearr_v4uint_1 = OpTypeRuntimeArray %v4uint + %Output = OpTypeStruct %_runtimearr_v4uint_1 +%_ptr_StorageBuffer_Output = OpTypePointer StorageBuffer %Output + %bufOut = OpVariable %_ptr_StorageBuffer_Output StorageBuffer + %uint_4 = OpConstant %uint 4 + %v3uint = OpTypeVector %uint 3 + %74 = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1 + %main = OpFunction %void None %4 + %6 = OpLabel + %tempArg = OpVariable %_ptr_Function_13 Function + %A = OpVariable %_ptr_Function_13 Function + %tempArg_0 = OpVariable %_ptr_Function_13 Function + %B = OpVariable %_ptr_Function_13 Function + %i = OpVariable %_ptr_Function_int Function + %C = OpVariable %_ptr_Function_13 Function + %25 = OpAccessChain %_ptr_StorageBuffer_v4uint %bufInA %int_0 %uint_0 + %29 = OpCooperativeMatrixLoadKHR %13 %25 %uint_1 %uint_2 None + OpStore %tempArg %29 + %31 = OpLoad %13 %tempArg + OpStore %A %31 + %37 = OpAccessChain %_ptr_StorageBuffer_v4uint %bufInB %int_0 %uint_0 + %38 = OpCooperativeMatrixLoadKHR %13 %37 %uint_1 %uint_2 None + OpStore %tempArg_0 %38 + %40 = OpLoad %13 %tempArg_0 + OpStore %B %40 + OpStore %i %int_0 + OpBranch %43 + %43 = OpLabel + OpLoopMerge %45 %46 Unroll + OpBranch %47 + %47 = OpLabel + %48 = OpLoad %int %i + %50 = OpCooperativeMatrixLengthKHR %uint %13 + %51 = OpSLessThan %bool %48 %50 + OpBranchConditional %51 %44 %45 + %44 = OpLabel + %53 = OpLoad %int %i + %54 = OpLoad %int %i + %56 = OpAccessChain %_ptr_Function_half %A %54 + %57 = OpLoad %half %56 + %58 = OpLoad %int %i + %59 = OpAccessChain %_ptr_Function_half %B %58 + %60 = OpLoad %half %59 + %61 = OpFMul %half %57 %60 + %62 = OpAccessChain %_ptr_Function_half %C %53 + OpStore %62 %61 + OpBranch %46 + %46 = OpLabel + %63 = OpLoad %int %i + %65 = OpIAdd %int %63 %int_1 + OpStore %i %65 + OpBranch %43 + %45 = OpLabel + %66 = OpLoad %13 %C + %71 = OpAccessChain %_ptr_StorageBuffer_v4uint %bufOut %int_0 %uint_0 + OpCooperativeMatrixStoreKHR %71 %66 %uint_1 %uint_4 None + OpReturn + OpFunctionEnd + +; CHECK-LABEL: @lgc.shader.CS.main( +; CHECK-LABEL: .entry: +; CHECK-NEXT: [[TMP0:%.*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 2, i32 0, i32 2) +; CHECK-NEXT: [[TMP1:%.*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 1, i32 0, i32 2) +; CHECK-NEXT: [[TMP2:%.*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 0, i32 0, i32 2) +; CHECK-NEXT: [[TMP3:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) [[TMP2]], i32 32, i1 true, i32 1, i32 0, i32 0) #[[ATTR1:[0-9]+]] +; CHECK-NEXT: [[TMP4:%.*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) [[TMP1]], i32 32, i1 true, i32 1, i32 0, i32 0) #[[ATTR1]] +; CHECK-NEXT: br label [[TMP5:%.*]] +; CHECK: 5: +; CHECK-NEXT: [[DOT011:%.*]] = phi i32 [ 0, [[DOTENTRY:%.*]] ], [ [[TMP13:%.*]], [[TMP8:%.*]] ] +; CHECK-NEXT: [[DOT0:%.*]] = phi <8 x float> [ undef, [[DOTENTRY]] ], [ [[TMP12:%.*]], [[TMP8]] ] +; CHECK-NEXT: [[TMP6:%.*]] = call i32 @lgc.cooperative.matrix.length.i32.i32.i32(i32 1, i32 0) #[[ATTR2:[0-9]+]] +; CHECK-NEXT: [[TMP7:%.*]] = icmp slt i32 [[DOT011]], [[TMP6]] +; CHECK-NEXT: [[FR:%.*]] = freeze i1 [[TMP7]] +; CHECK-NEXT: br i1 [[FR]], label [[TMP8]], label [[TMP14:%.*]] +; CHECK: 8: +; CHECK-NEXT: [[TMP9:%.*]] = call half @lgc.cooperative.matrix.extract.f16.v8f32.i32.i32.i32(<8 x float> [[TMP3]], i32 [[DOT011]], i32 1, i32 0) #[[ATTR2]] +; CHECK-NEXT: [[TMP10:%.*]] = call half @lgc.cooperative.matrix.extract.f16.v8f32.i32.i32.i32(<8 x float> [[TMP4]], i32 [[DOT011]], i32 1, i32 0) #[[ATTR2]] +; CHECK-NEXT: [[TMP11:%.*]] = fmul reassoc nnan nsz arcp contract afn half [[TMP9]], [[TMP10]] +; CHECK-NEXT: [[TMP12]] = call <8 x float> @lgc.cooperative.matrix.insert.v8f32.v8f32.f16.i32.i32.i32(<8 x float> [[DOT0]], half [[TMP11]], i32 [[DOT011]], i32 1, i32 0) #[[ATTR2]] +; CHECK-NEXT: [[TMP13]] = add i32 [[DOT011]], 1 +; CHECK-NEXT: br label [[TMP5]], !llvm.loop [[LOOP8:![0-9]+]] +; CHECK: 14: +; CHECK-NEXT: call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) [[TMP0]], i32 64, i1 true, i32 1, i32 0, i32 0, <8 x float> [[DOT0]]) #[[ATTR3:[0-9]+]] +; CHECK-NEXT: ret void +; diff --git a/llpc/test/shaderdb/gfx11/cooperativeMatrix/lit.local.cfg b/llpc/test/shaderdb/gfx11/cooperativeMatrix/lit.local.cfg new file mode 100644 index 0000000000..a4266bc874 --- /dev/null +++ b/llpc/test/shaderdb/gfx11/cooperativeMatrix/lit.local.cfg @@ -0,0 +1,2 @@ +if "vki_cooperative_matrix" not in config.available_features: + config.unsupported = True diff --git a/llpc/test/shaderdb/gfx11/cooperativeMatrix/loadstore-uvec4.comp b/llpc/test/shaderdb/gfx11/cooperativeMatrix/loadstore-uvec4.comp new file mode 100644 index 0000000000..c97078f7ac --- /dev/null +++ b/llpc/test/shaderdb/gfx11/cooperativeMatrix/loadstore-uvec4.comp @@ -0,0 +1,30 @@ +// NOTE: Assertions have been autogenerated by tool/update_llpc_test_checks.py +// RUN: amdllpc -o - -gfxip 11.0 -emit-lgc %s | FileCheck -check-prefixes=CHECK %s + +// Test that the stride of cooperative matrix load/store operations is handled +// correctly when the array element type is not the same as the matrix element +// type. +#version 450 core +#pragma use_vulkan_memory_model +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float32: enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable + +layout(set=0, binding=0, std430) buffer Input { uvec4 x[]; } bufIn; +layout(set=0, binding=1, std430) buffer Output { uvec4 x[]; } bufOut; + +void main() { + coopmat matrix; + coopMatLoad(matrix, bufIn.x, 0, 4, 0); + coopMatStore(matrix, bufOut.x, 0, 4, 0); +} + +// CHECK-LABEL: @lgc.shader.CS.main( +// CHECK-LABEL: .entry: +// CHECK-NEXT: [[TMP0:%[0-9]*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 1, i32 0, i32 2) +// CHECK-NEXT: [[TMP1:%[0-9]*]] = call ptr addrspace(7) (...) @lgc.create.load.buffer.desc.p7(i64 0, i32 0, i32 0, i32 2) +// CHECK-NEXT: [[TMP2:%[0-9]*]] = call <8 x float> @lgc.cooperative.matrix.load.v8f32.p7.i32.i1.i32.i32.i32(ptr addrspace(7) [[TMP1]], i32 64, i1 true, i32 1, i32 0, i32 0) #[[ATTR1:[0-9]+]] +// CHECK-NEXT: call void @lgc.cooperative.matrix.store.p7.i32.i1.i32.i32.i32.v8f32(ptr addrspace(7) [[TMP0]], i32 64, i1 true, i32 1, i32 0, i32 0, <8 x float> [[TMP2]]) #[[ATTR2:[0-9]+]] +// CHECK-NEXT: ret void +// diff --git a/llpc/test/shaderdb/object/ObjInput_TestUnUsedVariable_lit.comp b/llpc/test/shaderdb/object/ObjInput_TestUnUsedVariable_lit.comp new file mode 100644 index 0000000000..947a38dbe3 --- /dev/null +++ b/llpc/test/shaderdb/object/ObjInput_TestUnUsedVariable_lit.comp @@ -0,0 +1,20 @@ +#version 450 core + +layout (set=0, binding=0) buffer MyBuffer { + float n; +} myBuffer; + +layout (set=0, binding=0) uniform sampler mySampler; +layout (set=0, binding=1) uniform texture2D myTexture; + +void main() +{ + textureLod(sampler2D(myTexture, mySampler), vec2(0.0, 0.0), 0).x; +} + +// BEGIN_SHADERTEST +/* +; RUN: amdllpc -v --auto-layout-desc %gfxip %s | FileCheck -check-prefix=SHADERTEST %s +; SHADERTEST: AMDLLPC SUCCESS +*/ +// END_SHADERTEST diff --git a/llpc/tool/llpcAutoLayout.cpp b/llpc/tool/llpcAutoLayout.cpp index 04f7aa435d..65d656677a 100644 --- a/llpc/tool/llpcAutoLayout.cpp +++ b/llpc/tool/llpcAutoLayout.cpp @@ -41,7 +41,9 @@ #include "llpcCompilationUtils.h" #include "llpcDebug.h" #include "llpcUtil.h" +#include "spvgen.h" #include "vfx.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Format.h" #define DEBUG_TYPE "llpc-auto-layout" @@ -251,8 +253,29 @@ bool checkPipelineStateCompatible(const ICompiler *compiler, Llpc::GraphicsPipel void doAutoLayoutDesc(ShaderStage shaderStage, BinaryData spirvBin, GraphicsPipelineBuildInfo *pipelineInfo, PipelineShaderInfo *shaderInfo, ResourceMappingNodeMap &resNodeSets, unsigned &pushConstSize, bool autoLayoutDesc, bool reverseThreadGroup) { + + const void *spvBuf = spirvBin.pCode; + unsigned spvBufSize = spirvBin.codeSize; + + // Remove the unused variables. + void *optBuf = nullptr; + unsigned optBufSize = 0; + const char *options[] = {"--remove-unused-interface-variables", "--eliminate-dead-variables"}; + bool ret = spvOptimizeSpirv(spirvBin.codeSize, spirvBin.pCode, sizeof(options) / sizeof(options[0]), options, + &optBufSize, &optBuf, 0, nullptr); + if (ret) { + spvBuf = optBuf; + spvBufSize = optBufSize; + } + + // Release optimized spirv data. + auto freeSpvData = make_scope_exit([&] { + if (ret) + free(optBuf); + }); + // Read the SPIR-V. - std::string spirvCode(static_cast(spirvBin.pCode), spirvBin.codeSize); + std::string spirvCode(static_cast(spvBuf), spvBufSize); std::istringstream spirvStream(spirvCode); std::unique_ptr module(SPIRVModule::createSPIRVModule()); spirvStream >> *module; diff --git a/llpc/translator/lib/SPIRV/SPIRVReader.cpp b/llpc/translator/lib/SPIRV/SPIRVReader.cpp index 8011085926..8cc62a367e 100644 --- a/llpc/translator/lib/SPIRV/SPIRVReader.cpp +++ b/llpc/translator/lib/SPIRV/SPIRVReader.cpp @@ -834,6 +834,26 @@ Type *SPIRVToLLVM::transTypeWithOpcode(SPIRVType *const spvType, c return FixedVectorType::get(compType, spvType->getVectorComponentCount()); } +// ===================================================================================================================== +// Translate an "OpTypeCooperativeMatrixKHR". +// @param spvType : The type. +// @param matrixStride : The matrix stride (can be 0). +// @param isColumnMajor : Whether the matrix is column major. +// @param isParentPointer : If the parent is a pointer type. +// @param layout : The layout mode will be used for the type translation. +template <> +Type *SPIRVToLLVM::transTypeWithOpcode(SPIRVType *const spvType, + const unsigned matrixStride, + const bool isColumnMajor, const bool isParentPointer, + LayoutMode layout) { + auto elemType = mapToBasicType(spvType->getCooperativeMatrixKHRComponentType()); + auto use = spvType->getCooperativeMatrixKHRUse(); + unsigned rows = spvType->getCooperativeMatrixKHRRows(); + unsigned columns = spvType->getCooperativeMatrixKHRColumns(); + auto matrixLayout = getCooperativeMatrixKHRLayout(static_cast(use), elemType, rows, columns); + return getBuilder()->getCooperativeMatrixTy(elemType, matrixLayout); +} + // ===================================================================================================================== // Get pointee type from SPIRV Value. // @@ -978,6 +998,11 @@ Type *SPIRVToLLVM::transTypeImpl(SPIRVType *t, unsigned matrixStride, bool colum Type *newTy = transTypeWithOpcode(t, matrixStride, columnMajor, parentIsPointer, layout); return parentIsPointer ? newTy : mapType(t, newTy); } + case OpTypeCooperativeMatrixKHR: { + Type *newTy = + transTypeWithOpcode(t, matrixStride, columnMajor, parentIsPointer, layout); + return parentIsPointer ? newTy : mapType(t, newTy); + } default: { llvm_unreachable("Not implemented"); } @@ -1109,6 +1134,25 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *bv, Function *f, BasicBlock *bb auto dstType = transType(bc->getType()); CastInst::CastOps co = Instruction::BitCast; + lgc::Builder::CooperativeMatrixElementType srcElemTy = lgc::Builder::CooperativeMatrixElementType::Unknown; + lgc::Builder::CooperativeMatrixElementType dstElemTy = lgc::Builder::CooperativeMatrixElementType::Unknown; + lgc::Builder::CooperativeMatrixLayout srcLayout = lgc::Builder::CooperativeMatrixLayout::InvalidLayout; + lgc::Builder::CooperativeMatrixLayout dstLayout = lgc::Builder::CooperativeMatrixLayout::InvalidLayout; + + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + auto srcCompType = static_cast(bc->getOperand(0)->getType()) + ->getCooperativeMatrixKHRComponentType(); + srcElemTy = mapToBasicType(srcCompType); + auto dstCompType = + static_cast(bc->getType())->getCooperativeMatrixKHRComponentType(); + dstElemTy = mapToBasicType(dstCompType); + auto dstUse = static_cast(bc->getType())->getCooperativeMatrixKHRUse(); + unsigned rows = static_cast(bc->getType())->getCooperativeMatrixKHRRows(); + unsigned columns = static_cast(bc->getType())->getCooperativeMatrixKHRColumns(); + dstLayout = getCooperativeMatrixKHRLayout(static_cast(dstUse), dstElemTy, rows, columns); + srcLayout = getCooperativeMatrixKHRLayout(static_cast(dstUse), srcElemTy, rows, columns); + } + bool isExt = dstType->getScalarSizeInBits() > srcType->getScalarSizeInBits(); switch (bc->getOpCode()) { case OpSConvert: @@ -1128,6 +1172,9 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *bv, Function *f, BasicBlock *bb return src; assert(CastInst::isCast(co) && "Invalid cast op code"); if (bb) { + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + return getBuilder()->CreateCooperativeMatrixConvert(co, src, srcElemTy, dstElemTy, srcLayout, dstLayout); + } bool srcIsPtr = srcType->isPtrOrPtrVectorTy(); bool dstIsPtr = dstType->isPtrOrPtrVectorTy(); // OpBitcast in SPIR-V allows casting between pointers and integers (and integer vectors), @@ -1255,6 +1302,8 @@ void SPIRVToLLVM::setFastMathFlags(Value *val) { Value *SPIRVToLLVM::transShiftLogicalBitwiseInst(SPIRVValue *bv, BasicBlock *bb, Function *f) { SPIRVBinary *bbn = static_cast(bv); assert(bb && "Invalid BB"); + if (bbn->getOperand(0)->getType()->isTypeCooperativeMatrixKHR()) + return transCooperativeMatrixArithInst(bv, bb); Instruction::BinaryOps bo; auto op = bbn->getOpCode(); if (isLogicalOpCode(op)) @@ -3153,6 +3202,30 @@ template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVValue * spvAccessType = spvAccessType->getVectorComponentType(); break; } + case OpTypeCooperativeMatrixKHR: { + flushGep(); + auto use = spvAccessType->getCooperativeMatrixKHRUse(); + unsigned rows = spvAccessType->getCooperativeMatrixKHRRows(); + unsigned columns = spvAccessType->getCooperativeMatrixKHRColumns(); + spvAccessType = spvAccessType->getCooperativeMatrixKHRComponentType(); + basePointeeType = transType(spvAccessType); + lgc::BuilderCommon::CooperativeMatrixElementType elemType = mapToBasicType(spvAccessType); + lgc::BuilderCommon::CooperativeMatrixLayout layout = + getCooperativeMatrixKHRLayout(static_cast(use), elemType, rows, columns); + + std::string mangledName(LlpcName::SpirvCooperativeMatrixProxy); + Value *args[] = { + base, + getBuilder()->getInt32((unsigned)elemType), + getBuilder()->getInt32((unsigned)layout), + }; + Type *retType = basePointeeType->getPointerTo(base->getType()->getPointerAddressSpace()); + appendTypeMangling(retType, args, mangledName); + base = getBuilder()->CreateNamedCall(mangledName, retType, args, {Attribute::ReadNone, Attribute::NoUnwind}); + + gepIndices[0] = index; + break; + } default: llvm_unreachable("unhandled type in access chain"); } @@ -4700,7 +4773,18 @@ template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVV Function *const func = getBuilder()->GetInsertBlock()->getParent(); Value *const matrix = transValue(spvOperands[0], func, block); Value *const scalar = transValue(spvOperands[1], func, block); - { return getBuilder()->CreateMatrixTimesScalar(matrix, scalar); } + if (spvOperands[0]->getType()->isTypeCooperativeMatrixKHR()) { + SPIRVType *elemSpvType = spvOperands[0]->getType()->getCooperativeMatrixKHRComponentType(); + unsigned rows = spvOperands[0]->getType()->getCooperativeMatrixKHRRows(); + unsigned columns = spvOperands[0]->getType()->getCooperativeMatrixKHRColumns(); + lgc::Builder::CooperativeMatrixElementType elemType = mapToBasicType(elemSpvType); + lgc::Builder::CooperativeMatrixLayout layout = getCooperativeMatrixKHRLayout( + static_cast(spvOperands[0]->getType()->getCooperativeMatrixKHRUse()), elemType, rows, + columns); + return getBuilder()->CreateCoopMatrixTimesScalar(matrix, scalar, elemType, layout); + } else { + return getBuilder()->CreateMatrixTimesScalar(matrix, scalar); + } } // ===================================================================================================================== @@ -4790,6 +4874,313 @@ Value *SPIRVToLLVM::transString(const SPIRVString *spvValue) { return mapEntry(spvValue, global); } +// ===================================================================================================================== +// Handle mapToBasicType: translate the element type +// @param elemType : A SPIR-V type. +// | A/B type | C/D type | gfx11 | LGC | +// |----------|----------|-------|-----| +// | f16 | f32 | Y | Y | +// | bf16 | f32 | Y | N | +// | f16 | f16 | Y | Y | +// | bf16 | bf16 | Y | N | +// | iu8 | i32 | Y | Y | +// | iu4 | i32 | Y | N | +// For integer types, arbitrary signedness combinations are supported for the +// A/B matrices.C/D matrices are always signed. + +lgc::Builder::CooperativeMatrixElementType SPIRVToLLVM::mapToBasicType(Type *const elemType) { + lgc::Builder::CooperativeMatrixElementType basicTy = lgc::Builder::CooperativeMatrixElementType::Unknown; + if (elemType->isIntegerTy(8)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int8; + } else if (elemType->isIntegerTy(16)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int16; + } else if (elemType->isIntegerTy(32)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int32; + } else if (elemType->isFloatTy()) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Float32; + } else if (elemType->isHalfTy()) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Float16; + } else { + llvm_unreachable("The element type is not supported!"); + } + return basicTy; +} + +lgc::Builder::CooperativeMatrixElementType SPIRVToLLVM::mapToBasicType(SPIRVType *const elemType) { + lgc::Builder::CooperativeMatrixElementType basicTy = lgc::Builder::CooperativeMatrixElementType::Unknown; + if (elemType->isTypeInt(8)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int8; + } else if (elemType->isTypeInt(16)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int16; + } else if (elemType->isTypeInt(32)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Int32; + } else if (elemType->isTypeFloat(32)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Float32; + } else if (elemType->isTypeFloat(16)) { + basicTy = lgc::Builder::CooperativeMatrixElementType::Float16; + } else { + llvm_unreachable("The element type is not supported!"); + } + return basicTy; +} + +lgc::Builder::CooperativeMatrixLayout SPIRVToLLVM::getLayout(lgc::Builder::CooperativeMatrixElementType elemType) { + const Vkgc::GfxIpVersion gfxIp = static_cast(m_context)->getPipelineContext()->getGfxIpVersion(); + if (elemType == lgc::Builder::CooperativeMatrixElementType::Int32 || + elemType == lgc::Builder::CooperativeMatrixElementType::Float32) { + if (gfxIp.major == 11) + return lgc::Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout; + return lgc::Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout; + } + if (elemType == lgc::Builder::CooperativeMatrixElementType::Int16 || + elemType == lgc::Builder::CooperativeMatrixElementType::Int8 || + elemType == lgc::Builder::CooperativeMatrixElementType::Float16) { + return lgc::Builder::CooperativeMatrixLayout::FactorMatrixLayout; + } + llvm_unreachable("The element type is not supported!"); + return lgc::Builder::CooperativeMatrixLayout::InvalidLayout; +} + +// ===================================================================================================================== +// Mapping the use to layout +// @param use : CooperativeMatrixUse value. +// @param elemType : The type for the CooperativeMatrix element. +// @param rows: The size of the row for the CooperativeMatrix. +// @param columns: The size of the column for the CooperativeMatrix. +lgc::Builder::CooperativeMatrixLayout SPIRVToLLVM::getCooperativeMatrixKHRLayout( + CooperativeMatrixUse use, lgc::Builder::CooperativeMatrixElementType elemType, unsigned rows, unsigned columns) { + const Vkgc::GfxIpVersion gfxIp = static_cast(m_context)->getPipelineContext()->getGfxIpVersion(); + if (use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAKHR || + use == CooperativeMatrixUse::CooperativeMatrixUseMatrixBKHR) { + return lgc::Builder::CooperativeMatrixLayout::FactorMatrixLayout; + } + if (use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAccumulatorKHR) { + if (gfxIp.major == 11) + return lgc::Builder::CooperativeMatrixLayout::AccumulatorMatrixLayout; + if (elemType == lgc::Builder::CooperativeMatrixElementType::Float32 || + elemType == lgc::Builder::CooperativeMatrixElementType::Int32) + return lgc::Builder::CooperativeMatrixLayout::Gfx10AccumulatorMatrixLayout; + if (elemType == lgc::Builder::CooperativeMatrixElementType::Int16 || + elemType == lgc::Builder::CooperativeMatrixElementType::Float16) + return lgc::Builder::CooperativeMatrixLayout::Gfx10Accumulator16bitMatrixLayout; + llvm_unreachable("Invalid element type!"); + return lgc::Builder::CooperativeMatrixLayout::InvalidLayout; + } + llvm_unreachable("The element type is not supported!"); + return lgc::Builder::CooperativeMatrixLayout::InvalidLayout; +} + +// ===================================================================================================================== +// Handle OpCooperativeMatrixLengthKHR. +// @param spvValue : A SPIR-V value. +template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVValue *const spvValue) { + SPIRVCooperativeMatrixLengthKHR *const coopMatLength = static_cast(spvValue); + SPIRVType *matrixType = coopMatLength->getMatrixType(); + CooperativeMatrixUse matrixUse = static_cast(matrixType->getCooperativeMatrixKHRUse()); + auto elemType = mapToBasicType(matrixType->getCooperativeMatrixKHRComponentType()); + unsigned rows = matrixType->getCooperativeMatrixKHRRows(); + unsigned columns = matrixType->getCooperativeMatrixKHRColumns(); + auto layout = getCooperativeMatrixKHRLayout(matrixUse, elemType, rows, columns); + return getBuilder()->CreateCooperativeMatrixLength(elemType, layout); +} + +// ===================================================================================================================== +// Handle OpCooperativeMatrixLoadKHR. +// @param spvValue : A SPIR-V value. +template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVValue *const spvValue) { + SPIRVCooperativeMatrixLoadKHR *const coopMatLoad = static_cast(spvValue); + Function *fn = getBuilder()->GetInsertBlock()->getParent(); + BasicBlock *bb = getBuilder()->GetInsertBlock(); + + // Get Value: the pointer/strider/colMajor + Value *pointer = transValue(coopMatLoad->getSrc(), fn, bb); + Value *arrayStride = transValue(coopMatLoad->getStride(), fn, bb); + Value *colMajor = transValue(coopMatLoad->getColMajor(), fn, bb); + + // The lgc operation expects the stride to be in bytes. + auto pointeeSize = m_m->getDataLayout().getTypeStoreSize(getPointeeType(coopMatLoad->getSrc())).getFixedValue(); + assert(pointeeSize != 0 && "OpCooperativeMatrixLoadKHR pointee must be a scalar or vector"); + + Value *stride = getBuilder()->CreateMul(arrayStride, getBuilder()->getInt32(pointeeSize)); + + // Calc memoryAccess + unsigned memoryAccess = CooperativeMatrixMemoryAccessNone; + // Calc isVolatile + bool isVolatile = coopMatLoad->SPIRVMemoryAccess::isVolatile(true); + const Vkgc::ExtendedRobustness &extendedRobustness = getPipelineOptions()->extendedRobustness; + if (extendedRobustness.nullDescriptor || extendedRobustness.robustBufferAccess) + isVolatile |= coopMatLoad->getSrc()->isVolatile(); + // We don't require volatile on address spaces that become non-pointers. + switch (coopMatLoad->getSrc()->getType()->getPointerStorageClass()) { + case StorageClassInput: + case StorageClassOutput: + case StorageClassPrivate: + case StorageClassFunction: + isVolatile = false; + break; + default: + break; + } + + // Calc isCoherent + bool isCoherent = coopMatLoad->getSrc()->isCoherent(); + // MakePointerVisibleKHR is valid with OpCooperativeMatrixLoadKHR + if (coopMatLoad->getMemoryAccessMask(true) & MemoryAccessMakePointerVisibleKHRMask) { + SPIRVWord spvId = coopMatLoad->getMakeVisibleScope(true); + SPIRVConstant *const spvScope = static_cast(m_bm->getValue(spvId)); + const unsigned scope = spvScope->getZExtIntValue(); + const bool isSystemScope = (scope <= ScopeDevice || scope == ScopeQueueFamilyKHR); + if (isSystemScope) + isCoherent = true; + } + if (coopMatLoad->getMemoryAccessMask(true) & MemoryAccessNonPrivatePointerKHRMask) + isCoherent = true; + + // Calc isNonTempal + const bool isNonTemporal = coopMatLoad->SPIRVMemoryAccess::isNonTemporal(true); + if (isVolatile) { + memoryAccess |= CooperativeMatrixMemoryAccessVolatile; + } + if (isCoherent) { + memoryAccess |= CooperativeMatrixMemoryAccessCoherent; + } + if (isNonTemporal) { + memoryAccess |= CooperativeMatrixMemoryAccessTemporal; + } + + bool isColMajor = cast(colMajor)->getZExtValue(); + // Cal elemType + SPIRVType *elemSpvType = coopMatLoad->getType()->getCooperativeMatrixKHRComponentType(); + CooperativeMatrixUse use = static_cast(coopMatLoad->getType()->getCooperativeMatrixKHRUse()); + unsigned rows = static_cast(coopMatLoad->getType()->getCooperativeMatrixKHRRows()); + unsigned columns = static_cast(coopMatLoad->getType()->getCooperativeMatrixKHRColumns()); + if (use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAKHR) { + // Layout A is the transposition of the layout B, col_major_A = row_majow_B. + // FactorMatrixLayout is for B, so it needs inverse the layout when use is A. + isColMajor = !isColMajor; + } + lgc::Builder::CooperativeMatrixElementType elemType = mapToBasicType(elemSpvType); + lgc::Builder::CooperativeMatrixLayout layout = getCooperativeMatrixKHRLayout(use, elemType, rows, columns); + auto CoopMatLoadInst = + getBuilder()->CreateCooperativeMatrixLoad(pointer, stride, isColMajor, elemType, layout, memoryAccess); + return CoopMatLoadInst; +} + +// ===================================================================================================================== +// Handle OpCooperativeMatrixStoreKHR. +// @param spvValue : A SPIR-V value. +template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVValue *const spvValue) { + SPIRVCooperativeMatrixStoreKHR *const coopMatStore = static_cast(spvValue); + Function *fn = getBuilder()->GetInsertBlock()->getParent(); + BasicBlock *bb = getBuilder()->GetInsertBlock(); + + // Get Value: the pointer/strider/colMajor + Value *pointer = transValue(coopMatStore->getDest(), fn, bb); + Value *matrix = transValue(coopMatStore->getObject(), fn, bb); + Value *arrayStride = transValue(coopMatStore->getStride(), fn, bb); + Value *colMajor = transValue(coopMatStore->getColMajor(), fn, bb); + + // The lgc operation expects the stride to be in bytes. + auto pointeeSize = m_m->getDataLayout().getTypeStoreSize(getPointeeType(coopMatStore->getDest())).getFixedValue(); + assert(pointeeSize != 0 && "OpCooperativeMatrixStoreKHR pointee must be a scalar or vector"); + + Value *stride = getBuilder()->CreateMul(arrayStride, getBuilder()->getInt32(pointeeSize)); + + // Calc isVolatile + bool isVolatile = coopMatStore->SPIRVMemoryAccess::isVolatile(false); + const Vkgc::ExtendedRobustness &extendedRobustness = getPipelineOptions()->extendedRobustness; + if (extendedRobustness.nullDescriptor || extendedRobustness.robustBufferAccess) + isVolatile |= coopMatStore->getDest()->isVolatile(); + // We don't require volatile on address spaces that become non-pointers. + const auto pointerStorageClass = coopMatStore->getDest()->getType()->getPointerStorageClass(); + switch (pointerStorageClass) { + case StorageClassInput: + case StorageClassOutput: + case StorageClassPrivate: + case StorageClassFunction: + isVolatile = false; + break; + default: + break; + } + + // Calc isCoherent + bool isCoherent = coopMatStore->getDest()->isCoherent(); + // MakePointerAvailableKHR is valid with OpStore + if (coopMatStore->getMemoryAccessMask(false) & MemoryAccessMakePointerAvailableKHRMask) { + SPIRVWord spvId = coopMatStore->getMakeAvailableScope(false); + SPIRVConstant *const spvScope = static_cast(m_bm->getValue(spvId)); + const unsigned scope = spvScope->getZExtIntValue(); + const bool isSystemScope = (scope <= ScopeDevice || scope == ScopeQueueFamilyKHR); + if (isSystemScope) + isCoherent = true; + } + if (coopMatStore->getMemoryAccessMask(false) & MemoryAccessNonPrivatePointerKHRMask) + isCoherent = true; + + // Calc isNonTempal + const bool isNonTemporal = coopMatStore->SPIRVMemoryAccess::isNonTemporal(false); + + // Calc colMajor + bool isColMajor = cast(colMajor)->getZExtValue(); + + // Cal elemType + Type *const elemltType = transType(coopMatStore->getObject()->getType()->getCooperativeMatrixKHRComponentType()); + lgc::Builder::CooperativeMatrixElementType elemType = mapToBasicType(elemltType); + + CooperativeMatrixUse use = + static_cast(coopMatStore->getObject()->getType()->getCooperativeMatrixKHRUse()); + unsigned rows = coopMatStore->getObject()->getType()->getCooperativeMatrixKHRRows(); + unsigned columns = coopMatStore->getObject()->getType()->getCooperativeMatrixKHRColumns(); + + if (use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAKHR) { + // Layout A is the transposition of the layout B, col_major_A = row_majow_B. + // FactorMatrixLayout is for B, so it needs inverse the layout when use is A. + isColMajor = !isColMajor; + } + lgc::Builder::CooperativeMatrixLayout layout = getCooperativeMatrixKHRLayout(use, elemType, rows, columns); + unsigned memoryAccess = CooperativeMatrixMemoryAccessNone; + if (isVolatile) { + memoryAccess |= CooperativeMatrixMemoryAccessVolatile; + } + if (isCoherent) { + memoryAccess |= CooperativeMatrixMemoryAccessCoherent; + } + if (isNonTemporal) { + memoryAccess |= CooperativeMatrixMemoryAccessTemporal; + } + + getBuilder()->CreateCooperativeMatrixStore(pointer, matrix, stride, isColMajor, elemType, layout, memoryAccess); + return nullptr; +} + +// ===================================================================================================================== +// Handle OpCooperativeMatrixMulAddKHR. +// @param spvValue : A SPIR-V value. +template <> Value *SPIRVToLLVM::transValueWithOpcode(SPIRVValue *const spvValue) { + SPIRVInstruction *const spvInst = static_cast(spvValue); + std::vector spvOperands = spvInst->getOperands(); + BasicBlock *const block = getBuilder()->GetInsertBlock(); + Function *const func = getBuilder()->GetInsertBlock()->getParent(); + Value *coopMatrixA = transValue(spvOperands[0], func, block); + Value *coopMatrixB = transValue(spvOperands[1], func, block); + Value *coopMatrixC = transValue(spvOperands[2], func, block); + + SPIRVType *elemTypeA = spvOperands[0]->getType()->getCooperativeMatrixKHRComponentType(); + SPIRVType *elemTypeC = spvOperands[2]->getType()->getCooperativeMatrixKHRComponentType(); + + lgc::Builder::CooperativeMatrixElementType elemBasicTypeA = mapToBasicType(elemTypeA); + lgc::Builder::CooperativeMatrixElementType elemBasicTypeC = mapToBasicType(elemTypeC); + + bool isSignedA = static_cast(static_cast(spvInst)->getMatrixASigned()); + bool isSignedB = static_cast(static_cast(spvInst)->getMatrixBSigned()); + bool isSat = static_cast(static_cast(spvInst)->getMatrixSatAccumulation()); + + Value *coopMatrixD = getBuilder()->CreateCooperativeMatrixMulAdd(coopMatrixA, coopMatrixB, coopMatrixC, isSignedA, + isSignedB, isSat, elemBasicTypeC, elemBasicTypeA); + return coopMatrixD; +} + /// For instructions, this function assumes they are created in order /// and appended to the given basic block. An instruction may use a /// instruction from another BB which has not been translated. Such @@ -4896,6 +5287,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas case OpTypeMatrix: { return mapValue(bv, ConstantArray::get(dyn_cast(transType(bcc->getType())), cv)); } + case OpTypeCooperativeMatrixKHR: { + auto elements = transValue(bcc->getElements(), f, bb); + return mapValue(bv, transCooperativeMatrixKHRFromConstruct(bcc->getType(), elements)); + } default: llvm_unreachable("not implemented"); return nullptr; @@ -5207,6 +5602,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas v = InsertValueInst::Create(v, constituents[i], i, "", bb); return mapValue(bv, v); } + case OpTypeCooperativeMatrixKHR: { + return mapValue(bv, transCooperativeMatrixKHRFromConstruct(cc->getType(), constituents)); + } default: llvm_unreachable("Unhandled type!"); } @@ -5220,6 +5618,19 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas ConstantInt::get(*m_context, APInt(32, ce->getIndices()[0])), bv->getName(), bb)); } + if (ce->getComposite()->getType()->isTypeCooperativeMatrixKHR()) { + assert(ce->getIndices().size() == 1 && "Invalid index"); // Treating it as vector. + + SPIRVType *matrixType = ce->getComposite()->getType(); + auto elemType = mapToBasicType(matrixType->getCooperativeMatrixKHRComponentType()); + unsigned rows = matrixType->getCooperativeMatrixKHRRows(); + unsigned columns = matrixType->getCooperativeMatrixKHRColumns(); + auto layout = getCooperativeMatrixKHRLayout( + static_cast(matrixType->getCooperativeMatrixKHRUse()), elemType, rows, columns); + Value *matrix = transValue(ce->getComposite(), f, bb); + Value *index = getBuilder()->getInt32(ce->getIndices()[0]); + return mapValue(bv, getBuilder()->CreateCooperativeMatrixExtract(matrix, index, elemType, layout)); + } auto cv = transValue(ce->getComposite(), f, bb); auto indexedTy = ExtractValueInst::getIndexedType(cv->getType(), ce->getIndices()); @@ -5256,6 +5667,22 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas ConstantInt::get(*m_context, APInt(32, ci->getIndices()[0])), bv->getName(), bb)); } + if (ci->getComposite()->getType()->isTypeCooperativeMatrixKHR()) { + assert(ci->getIndices().size() == 1 && "Invalid index"); // Treating it as vector. + + SPIRVType *matrixType = ci->getComposite()->getType(); + auto elemType = mapToBasicType(matrixType->getCooperativeMatrixKHRComponentType()); + unsigned rows = matrixType->getCooperativeMatrixKHRRows(); + unsigned columns = matrixType->getCooperativeMatrixKHRColumns(); + auto layout = getCooperativeMatrixKHRLayout( + static_cast(matrixType->getCooperativeMatrixKHRUse()), elemType, rows, columns); + + Value *matrix = transValue(ci->getComposite(), f, bb); + Value *value = transValue(ci->getObject(), f, bb); + Value *index = getBuilder()->getInt32(ci->getIndices()[0]); + return mapValue(bv, getBuilder()->CreateCooperativeMatrixInsert(matrix, value, index, elemType, layout)); + } + auto cv = transValue(ci->getComposite(), f, bb); auto indexedTy = ExtractValueInst::getIndexedType(cv->getType(), ci->getIndices()); if (!indexedTy) { @@ -5352,22 +5779,34 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas return mapValue(bv, transBarrierFence(static_cast(bv), bb)); case OpSNegate: { + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(bv, transCooperativeMatrixArithInst(bv, bb)); + } SPIRVUnary *bc = static_cast(bv); return mapValue(bv, BinaryOperator::CreateNSWNeg(transValue(bc->getOperand(0), f, bb), bv->getName(), bb)); } case OpSMod: { + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(bv, transCooperativeMatrixArithInst(bv, bb)); + } SPIRVBinary *bc = static_cast(bv); Value *val0 = transValue(bc->getOperand(0), f, bb); Value *val1 = transValue(bc->getOperand(1), f, bb); return mapValue(bc, getBuilder()->CreateSMod(val0, val1)); } case OpFMod: { + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(bv, transCooperativeMatrixArithInst(bv, bb)); + } SPIRVFMod *bc = static_cast(bv); Value *val0 = transValue(bc->getDividend(), f, bb); Value *val1 = transValue(bc->getDivisor(), f, bb); return mapValue(bc, getBuilder()->CreateFMod(val0, val1)); } case OpFNegate: { + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + return mapValue(bv, transCooperativeMatrixArithInst(bv, bb)); + } SPIRVUnary *bc = static_cast(bv); Value *val0 = transValue(bc->getOperand(0), f, bb); auto fNeg = getBuilder()->CreateFNeg(val0); @@ -5382,6 +5821,24 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas SPIRVUnary *bc = static_cast(bv); Value *val = transValue(bc->getOperand(0), f, bb); Type *destTy = transType(bc->getType()); + // Can't use destTy as for transType will return packed element Type. + CastInst::CastOps co = Instruction::BitCast; + if (bv->getType()->isTypeCooperativeMatrixKHR()) { + SPIRVType *dstType = bc->getType()->getCooperativeMatrixKHRComponentType(); + lgc::Builder::CooperativeMatrixElementType basicDstElemTy = mapToBasicType(dstType); + SPIRVType *srcType = bc->getOperand(0)->getType()->getCooperativeMatrixKHRComponentType(); + bool isExt = dstType->getBitWidth() > srcType->getBitWidth(); + co = isExt ? Instruction::FPExt : Instruction::FPTrunc; + lgc::Builder::CooperativeMatrixElementType basicSrcElemTy = mapToBasicType(srcType); + lgc::Builder::CooperativeMatrixLayout srcLayout = getCooperativeMatrixKHRLayout( + static_cast(bc->getType()->getCooperativeMatrixKHRUse()), basicSrcElemTy, + bc->getType()->getCooperativeMatrixKHRRows(), bc->getType()->getCooperativeMatrixKHRColumns()); + lgc::Builder::CooperativeMatrixLayout dstLayout = getCooperativeMatrixKHRLayout( + static_cast(bc->getType()->getCooperativeMatrixKHRUse()), basicDstElemTy, + bc->getType()->getCooperativeMatrixKHRRows(), bc->getType()->getCooperativeMatrixKHRColumns()); + return mapValue(bv, getBuilder()->CreateCooperativeMatrixConvert(co, val, basicSrcElemTy, basicDstElemTy, + srcLayout, dstLayout)); + } if (val->getType()->getScalarType()->getPrimitiveSizeInBits() <= destTy->getScalarType()->getPrimitiveSizeInBits()) return mapValue(bv, getBuilder()->CreateFPExt(val, destTy)); @@ -5888,6 +6345,14 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *bv, Function *f, Bas return transValueWithOpcode(bv); case OpSetMeshOutputsEXT: return transValueWithOpcode(bv); + case OpCooperativeMatrixLengthKHR: + return mapValue(bv, transValueWithOpcode(bv)); + case OpCooperativeMatrixLoadKHR: + return mapValue(bv, transValueWithOpcode(bv)); + case OpCooperativeMatrixStoreKHR: + return mapValue(bv, transValueWithOpcode(bv)); + case OpCooperativeMatrixMulAddKHR: + return mapValue(bv, transValueWithOpcode(bv)); default: { auto oc = bv->getOpCode(); if (isSPIRVCmpInstTransToLLVMInst(static_cast(bv))) @@ -8082,9 +8547,10 @@ bool SPIRVToLLVM::transShaderDecoration(SPIRVValue *bv, Value *v) { auto locationFound = getUniformConstantEntryByLocation(static_cast(m_context), convertToShaderStage(m_execModule), loc); unsigned offset = locationFound == nullptr ? 0 : locationFound->offset; + unsigned constantBufferBinding = Vkgc::ConstantBuffer0Binding + m_shaderOptions->constantBufferBindingOffset; mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, Vkgc::InternalDescriptorSetId))); - mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, Vkgc::ConstantBuffer0Binding))); + mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, constantBufferBinding))); mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, offset))); mDs.push_back(ConstantAsMetadata::get(ConstantInt::get(int32Ty, loc))); auto mdNode = MDNode::get(*m_context, mDs); @@ -9801,7 +10267,7 @@ void SPIRVToLLVM::createXfbMetadata(bool hasXfbOuts) { auto llpcContext = static_cast(m_context); auto pipelineBuildInfo = static_cast(llpcContext->getPipelineBuildInfo()); bool needXfbMetadata = hasXfbOuts && !pipelineBuildInfo->apiXfbOutData.forceDisableStreamOut; -#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 69 +#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 70 needXfbMetadata |= pipelineBuildInfo->apiXfbOutData.forceEnablePrimStats; #endif @@ -9922,6 +10388,143 @@ void SPIRVToLLVM::createXfbMetadata(bool hasXfbOuts) { entryFunc->setMetadata(mdKindId, MDNode::get(*m_context, metadatas)); } +// ============================================================================= +// Translate cooperative matrix instructions to LLVM IR +Value *SPIRVToLLVM::transCooperativeMatrixArithInst(SPIRVValue *spvVal, BasicBlock *bb) { + auto oc = spvVal->getOpCode(); + Function *func = bb->getParent(); + Builder::CooperativeMatrixArithOp arithOp; + switch (oc) { + case OpFNegate: + arithOp = Builder::CooperativeMatrixArithOp::FSub; + break; + case OpSNegate: + arithOp = Builder::CooperativeMatrixArithOp::ISub; + break; + case OpFAdd: + arithOp = Builder::CooperativeMatrixArithOp::FAdd; + break; + case OpIAdd: + arithOp = Builder::CooperativeMatrixArithOp::IAdd; + break; + case OpISub: + arithOp = Builder::CooperativeMatrixArithOp::ISub; + break; + case OpFSub: + arithOp = Builder::CooperativeMatrixArithOp::FSub; + break; + case OpIMul: + arithOp = Builder::CooperativeMatrixArithOp::IMul; + break; + case OpFMul: + arithOp = Builder::CooperativeMatrixArithOp::FMul; + break; + case OpFDiv: + arithOp = Builder::CooperativeMatrixArithOp::FDiv; + break; + case OpSDiv: + arithOp = Builder::CooperativeMatrixArithOp::SDiv; + break; + case OpUDiv: + arithOp = Builder::CooperativeMatrixArithOp::UDiv; + break; + case OpFMod: + arithOp = Builder::CooperativeMatrixArithOp::FMod; + break; + case OpSMod: + arithOp = Builder::CooperativeMatrixArithOp::SMod; + break; + case OpUMod: + arithOp = Builder::CooperativeMatrixArithOp::UMod; + break; + case OpSRem: + arithOp = Builder::CooperativeMatrixArithOp::SRem; + break; + case OpFRem: + arithOp = Builder::CooperativeMatrixArithOp::FRem; + break; + default: + llvm_unreachable("Not support arithmetic for cooperative matrix"); + return nullptr; + } + + lgc::Builder::CooperativeMatrixLayout layout = lgc::Builder::CooperativeMatrixLayout::InvalidLayout; + lgc::Builder::CooperativeMatrixElementType elemType = lgc::Builder::CooperativeMatrixElementType::Unknown; + if (oc == OpFNegate || oc == OpSNegate) { + auto unary = static_cast(spvVal); + Value *srcVal = transValue(unary->getOperand(0), func, bb); + if (unary->getOperand(0)->getType()->isTypeCooperativeMatrixKHR()) { + SPIRVType *elemSpvType = unary->getOperand(0)->getType()->getCooperativeMatrixKHRComponentType(); + unsigned rows = unary->getOperand(0)->getType()->getCooperativeMatrixKHRRows(); + unsigned columns = unary->getOperand(0)->getType()->getCooperativeMatrixKHRColumns(); + elemType = mapToBasicType(elemSpvType); + layout = getCooperativeMatrixKHRLayout( + static_cast(unary->getOperand(0)->getType()->getCooperativeMatrixKHRUse()), elemType, + rows, columns); + } + return getBuilder()->CreateCooperativeMatrixBinaryOp(arithOp, Constant::getNullValue(srcVal->getType()), srcVal, + elemType, layout); + } else { + auto binary = static_cast(spvVal); + Value *lhs = transValue(binary->getOperand(0), func, bb); + Value *rhs = transValue(binary->getOperand(1), func, bb); + if (binary->getOperand(0)->getType()->isTypeCooperativeMatrixKHR()) { + SPIRVType *elemSpvType = binary->getOperand(0)->getType()->getCooperativeMatrixKHRComponentType(); + unsigned rows = binary->getOperand(0)->getType()->getCooperativeMatrixKHRRows(); + unsigned columns = binary->getOperand(0)->getType()->getCooperativeMatrixKHRColumns(); + elemType = mapToBasicType(elemSpvType); + layout = getCooperativeMatrixKHRLayout( + static_cast(binary->getOperand(0)->getType()->getCooperativeMatrixKHRUse()), elemType, + rows, columns); + } + return getBuilder()->CreateCooperativeMatrixBinaryOp(arithOp, lhs, rhs, elemType, layout); + } +} + +// ============================================================================= +// Translate cooperative matrix construction instructions to LLVM IR +Value *SPIRVToLLVM::transCooperativeMatrixKHRFromConstruct(SPIRVType *spvCoopMatTy, + const std::vector &constituents) { + auto vecTy = transType(spvCoopMatTy); + Value *matrixResult = PoisonValue::get(vecTy); + unsigned subElemNums = 0; + unsigned elemNums = 0; + lgc::Builder::CooperativeMatrixElementType componentType = + mapToBasicType(spvCoopMatTy->getCooperativeMatrixKHRComponentType()); + unsigned duplicateFoldFactor = 1; + + switch (componentType) { + case lgc::Builder::CooperativeMatrixElementType::Int8: // A/B + subElemNums = 4; + elemNums = 4 / duplicateFoldFactor; + break; + case lgc::Builder::CooperativeMatrixElementType::Int32: // C/D + case lgc::Builder::CooperativeMatrixElementType::Float32: + subElemNums = 1; + elemNums = 8; // label:changewaveSize + break; + case lgc::Builder::CooperativeMatrixElementType::Int16: // A/B + case lgc::Builder::CooperativeMatrixElementType::Float16: + subElemNums = 2; + elemNums = 8 / duplicateFoldFactor; + break; + default: + llvm_unreachable("The component type is not be supported."); + } + + for (unsigned idx = 0; idx < elemNums; ++idx) { + Type *subElemTy = transType(spvCoopMatTy->getCooperativeMatrixKHRComponentType()); + Type *subVecTy = FixedVectorType::get(subElemTy, subElemNums); + Value *elem = PoisonValue::get(subVecTy); + for (unsigned subIdx = 0; subIdx < subElemNums; ++subIdx) + elem = getBuilder()->CreateInsertElement(elem, constituents[0], subIdx); // The value to initialize all members + + elem = getBuilder()->CreateBitCast(elem, cast(vecTy)->getElementType()); + matrixResult = getBuilder()->CreateInsertElement(matrixResult, elem, idx); + } + return matrixResult; +} + } // namespace SPIRV bool llvm::readSpirv(Builder *builder, const ShaderModuleUsage *shaderInfo, const PipelineShaderOptions *shaderOptions, diff --git a/llpc/translator/lib/SPIRV/SPIRVReader.h b/llpc/translator/lib/SPIRV/SPIRVReader.h index 69b733064b..7a6da16136 100644 --- a/llpc/translator/lib/SPIRV/SPIRVReader.h +++ b/llpc/translator/lib/SPIRV/SPIRVReader.h @@ -282,6 +282,23 @@ class SPIRVToLLVM { SmallVector llvmInstructions; }; + lgc::Builder::CooperativeMatrixElementType mapToBasicType(Type *const ltType); + lgc::Builder::CooperativeMatrixElementType mapToBasicType(SPIRVType *const spvType); + lgc::Builder::CooperativeMatrixLayout getLayout(lgc::Builder::CooperativeMatrixElementType elemTy); + lgc::Builder::CooperativeMatrixLayout getCooperativeMatrixKHRLayout(CooperativeMatrixUse use, + lgc::Builder::CooperativeMatrixElementType elemTy, + unsigned rows, unsigned columns); + + enum CooperativeMatrixMemoryAccess { + CooperativeMatrixMemoryAccessNone = 0x00, + CooperativeMatrixMemoryAccessVolatile = 0x01, + CooperativeMatrixMemoryAccessCoherent = 0x02, + CooperativeMatrixMemoryAccessTemporal = 0x04, + }; + + Value *transCooperativeMatrixArithInst(SPIRVValue *spvVal, BasicBlock *bb); + Value *transCooperativeMatrixKHRFromConstruct(SPIRVType *spvCoopMatRowTy, const std::vector &constituents); + // Stores pointers of LLVM Functions to SPIRV memops to the translated LLVM memop(s) in a MapVector to preserve // insertion order of the SPIRV memops and to preserve the function origins, as the bounds checks need to be // inserted on a per-function level. To handle dependencies between the LLVM IR memops, e. g. using a load result as diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVEnum.h b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVEnum.h index 4751beaa51..53cdddf9f5 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVEnum.h +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVEnum.h @@ -210,6 +210,7 @@ template <> inline void SPIRVMap::init() { ADD_VEC_INIT(CapabilityDotProductInput4x8BitKHR, {CapabilityInt16}); ADD_VEC_INIT(CapabilityMeshShadingEXT, {CapabilityShader}); ADD_VEC_INIT(CapabilityFragmentBarycentricKHR, {CapabilityShader}); + ADD_VEC_INIT(CapabilityCooperativeMatrixKHR, {CapabilityShader}); } template <> inline void SPIRVMap::init() { diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp index 32c608b958..5cededd88c 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp @@ -381,6 +381,28 @@ SPIRVValue *createValueFromSpecConstantOp(SPIRVSpecConstantOp *Inst, uint32_t Ro return constantCompositeInsert(Composite, Object, Indices); + } else if (OC == OpCooperativeMatrixLengthKHR) { + // TODO: This is subtly broken given different matrix layouts and wave sizes. Can we forbid use of this in a + // specification constant at the spec level? + // + // We work around this by: + // - producing whatever the maximum length is going to be (based on wave32) + // - adding some masking during lowering if we happen to compile for wave64, so that we don't access vectors + // out of bounds + assert(DestTy->isTypeScalar() && Ops.size() == 1); + SPIRVType *type = static_cast(BM->getEntry(Ops[0])); + assert(type->isTypeCooperativeMatrixKHR()); + (void)type; + unsigned length = 0; + CooperativeMatrixUse use = static_cast(type->getCooperativeMatrixKHRUse()); + if (use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAKHR || + use == CooperativeMatrixUse::CooperativeMatrixUseMatrixBKHR) { + length = 16; // Assume factor layout + } else { + assert(use == CooperativeMatrixUse::CooperativeMatrixUseMatrixAccumulatorKHR); + length = 8; // Assume wave32 accumulation layout + } + return BM->addConstant(DestTy, length); } else { assert(DestTy->isTypeVector() || DestTy->isTypeScalar()); diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.h index c5b52676a4..adb656f363 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -649,6 +649,9 @@ class SPIRVBinary : public SPIRVInstTemplateBase { Op1Ty = getValueType(Op1)->getVectorComponentType(); } else if (getValueType(Op1)->isTypeMatrix()) { Op1Ty = getValueType(Op1)->getMatrixColumnType()->getVectorComponentType(); + } else if (getValueType(Op1)->isTypeCooperativeMatrixKHR()) { + assert((OpCode >= OpIAdd && OpCode <= OpFMod) || (OpCode == OpMatrixTimesScalar)); + Op1Ty = getValueType(Op1)->getCooperativeMatrixKHRComponentType(); } else { Op1Ty = getValueType(Op1); } @@ -657,6 +660,9 @@ class SPIRVBinary : public SPIRVInstTemplateBase { Op2Ty = getValueType(Op2)->getVectorComponentType(); } else if (getValueType(Op2)->isTypeMatrix()) { Op2Ty = getValueType(Op2)->getMatrixColumnType()->getVectorComponentType(); + } else if (getValueType(Op2)->isTypeCooperativeMatrixKHR()) { + assert(OpCode >= OpIAdd && OpCode <= OpFMod); + Op2Ty = getValueType(Op2)->getCooperativeMatrixKHRComponentType(); } else { Op2Ty = getValueType(Op2); } @@ -1302,7 +1308,12 @@ class SPIRVUnary : public SPIRVInstTemplateBase { if (isGenericNegateOpCode(OpCode)) { SPIRVType *ResTy = nullptr; SPIRVType *OpTy = nullptr; - { + + if (Type->isTypeCooperativeMatrixKHR() && + (static_cast(OpCode) == OpSNegate || static_cast(OpCode) == OpFNegate)) { + ResTy = Type->getCooperativeMatrixKHRComponentType(); + OpTy = getValueType(Op)->getCooperativeMatrixKHRComponentType(); + } else { ResTy = Type->isTypeVector() ? Type->getVectorComponentType() : Type; OpTy = Type->isTypeVector() ? getValueType(Op)->getVectorComponentType() : getValueType(Op); } @@ -1316,6 +1327,15 @@ class SPIRVUnary : public SPIRVInstTemplateBase { : 1) && "Invalid vector component Width for Generic Negate instruction"); } + if (Type->isTypeCooperativeMatrixKHR() && static_cast(OpCode) >= OpConvertFToU && + static_cast(OpCode) <= OpFConvert) { + SPIRVType *OpTy = getValueType(Op); + assert(OpTy->isTypeCooperativeMatrixKHR() && + Type->getCooperativeMatrixKHRScope() == OpTy->getCooperativeMatrixKHRScope() && + Type->getCooperativeMatrixKHRRows() == OpTy->getCooperativeMatrixKHRRows() && + Type->getCooperativeMatrixKHRColumns() == OpTy->getCooperativeMatrixKHRColumns()); + (void)OpTy; + } } }; @@ -1649,7 +1669,8 @@ class SPIRVCompositeExtract : public SPIRVInstruction { void validate() const override { SPIRVInstruction::validate(); assert(getValueType(Composite)->isTypeArray() || getValueType(Composite)->isTypeStruct() || - getValueType(Composite)->isTypeVector() || getValueType(Composite)->isTypeMatrix()); + getValueType(Composite)->isTypeVector() || getValueType(Composite)->isTypeMatrix() || + getValueType(Composite)->isTypeCooperativeMatrixKHR()); } SPIRVId Composite; std::vector Indices; @@ -2433,6 +2454,226 @@ SPIRVSpecConstantOp *createSpecConstantOpInst(SPIRVInstruction *Inst); SPIRVInstruction *createInstFromSpecConstantOp(SPIRVSpecConstantOp *C); SPIRVValue *createValueFromSpecConstantOp(SPIRVSpecConstantOp *Inst, uint32_t RoundingTypeMask); +// For KHR extension +class SPIRVCooperativeMatrixLoadKHR : public SPIRVInstruction, public SPIRVMemoryAccess { +public: + const static SPIRVWord FixedWords = 6; // To update when the stride is optional. + // Complete constructor + SPIRVCooperativeMatrixLoadKHR(SPIRVId TheId, SPIRVId PointerId, SPIRVId MemLayout, SPIRVId TheStrideId, + const std::vector &TheMemoryAccess, SPIRVBasicBlock *TheBB) + : SPIRVInstruction(FixedWords + TheMemoryAccess.size(), OpCooperativeMatrixLoadKHR, + TheBB->getValueType(PointerId)->getPointerElementType(), TheId, TheBB), + SPIRVMemoryAccess(TheMemoryAccess), PtrId(PointerId), ColMajorId(MemLayout), StrideId(TheStrideId), + MemoryAccess(TheMemoryAccess) { + validate(); + assert(TheBB && "Invalid BB"); + } + // Incomplete constructor + SPIRVCooperativeMatrixLoadKHR() + : SPIRVInstruction(OpCooperativeMatrixLoadKHR), SPIRVMemoryAccess(), PtrId(SPIRVID_INVALID), + ColMajorId(SPIRVID_INVALID), StrideId(SPIRVID_INVALID) {} + + SPIRVCapVec getRequiredCapability() const override { return getVec(CapabilityCooperativeMatrixKHR); } + + SPIRVValue *getSrc() const { return Module->get(PtrId); } + SPIRVValue *getColMajor() const { return Module->get(ColMajorId); } + SPIRVValue *getStride() const { return Module->get(StrideId); } + +protected: + void setWordCount(SPIRVWord TheWordCount) override { + SPIRVEntry::setWordCount(TheWordCount); + MemoryAccess.resize(TheWordCount - FixedWords); + } + + void decode(std::istream &I) override { + getDecoder(I) >> Type >> Id >> PtrId >> ColMajorId >> StrideId >> MemoryAccess; + memoryAccessUpdate(MemoryAccess); + } + + void validate() const override { + SPIRVInstruction::validate(); + assert(Type->isTypeCooperativeMatrixKHR() && "Invalid type"); + assert((getValueType(PtrId)->getPointerElementType()->isTypeScalar() || + getValueType(PtrId)->getPointerElementType()->isTypeVector()) && + "Invalid pointer type"); + assert((getValueType(PtrId)->getPointerStorageClass() == StorageClassWorkgroup || + getValueType(PtrId)->getPointerStorageClass() == StorageClassStorageBuffer || + getValueType(PtrId)->getPointerStorageClass() == StorageClassPhysicalStorageBuffer) && + "Invalid storage class of Pointer"); + assert(getValueType(StrideId)->isTypeInt() && "Invalid stride type"); + assert((getValue(ColMajorId)->getOpCode() == OpConstant) && "Invalid colmajor type"); + } + +private: + SPIRVId PtrId; + SPIRVId ColMajorId; + SPIRVId StrideId; + std::vector MemoryAccess; +}; + +class SPIRVCooperativeMatrixStoreKHR : public SPIRVInstruction, public SPIRVMemoryAccess { +public: + const static SPIRVWord FixedWords = 5; // To update when the stride is optional. + // Complete constructor + SPIRVCooperativeMatrixStoreKHR(SPIRVId PointerId, SPIRVId ValueId, SPIRVId TheStrideId, SPIRVId ColumnMajorId, + const std::vector &TheMemoryAccess, SPIRVBasicBlock *TheBB) + : SPIRVInstruction(FixedWords + TheMemoryAccess.size(), OpCooperativeMatrixStoreKHR, TheBB), + SPIRVMemoryAccess(TheMemoryAccess), PtrId(PointerId), ObjectId(ValueId), ColMajorId(ColumnMajorId), + StrideId(TheStrideId), MemoryAccess(TheMemoryAccess) { + setAttr(); + validate(); + assert(TheBB && "Invalid BB"); + } + // Incomplete constructor + SPIRVCooperativeMatrixStoreKHR() + : SPIRVInstruction(OpCooperativeMatrixStoreKHR), SPIRVMemoryAccess(), PtrId(SPIRVID_INVALID), + ObjectId(SPIRVID_INVALID), ColMajorId(SPIRVID_INVALID), StrideId(SPIRVID_INVALID) { + setAttr(); + } + + SPIRVCapVec getRequiredCapability() const override { return getVec(CapabilityCooperativeMatrixKHR); } + SPIRVValue *getDest() const { return Module->get(PtrId); } + SPIRVValue *getObject() const { return Module->get(ObjectId); } + SPIRVValue *getColMajor() const { return Module->get(ColMajorId); } + SPIRVValue *getStride() const { return Module->get(StrideId); } + +protected: + void setAttr() { + setHasNoType(); + setHasNoId(); + } + + void setWordCount(SPIRVWord TheWordCount) override { + SPIRVEntry::setWordCount(TheWordCount); + MemoryAccess.resize(TheWordCount - FixedWords); + } + + void decode(std::istream &I) override { + getDecoder(I) >> PtrId >> ObjectId >> ColMajorId >> StrideId >> MemoryAccess; + memoryAccessUpdate(MemoryAccess); + } + + void validate() const override { + SPIRVInstruction::validate(); + assert(getValueType(ObjectId)->isTypeCooperativeMatrixKHR() && "Invalid object type"); + assert((getValueType(PtrId)->getPointerElementType()->isTypeScalar() || + getValueType(PtrId)->getPointerElementType()->isTypeVector()) && + "Invalid pointer type"); + assert((getValueType(PtrId)->getPointerStorageClass() == StorageClassWorkgroup || + getValueType(PtrId)->getPointerStorageClass() == StorageClassStorageBuffer || + getValueType(PtrId)->getPointerStorageClass() == StorageClassPhysicalStorageBuffer) && + "Invalid storage class of Pointer"); + assert(getValueType(StrideId)->isTypeInt() && "Invalid stride type"); + assert((getValue(ColMajorId)->getOpCode() == OpConstant) && "Invalid colmajor type"); + } + +private: + SPIRVId PtrId; + SPIRVId ObjectId; + SPIRVId ColMajorId; + SPIRVId StrideId; + std::vector MemoryAccess; +}; + +class SPIRVCooperativeMatrixLengthKHR : public SPIRVInstruction { +public: + const static SPIRVWord FixedWords = 4; + // Complete constructor + SPIRVCooperativeMatrixLengthKHR(SPIRVId TypeId, SPIRVId TheId, SPIRVId TheMatrixTypeId, SPIRVBasicBlock *TheBB) + : SPIRVInstruction(FixedWords, OpCooperativeMatrixLengthKHR, TheBB->get(TypeId), TheId, TheBB), + MatrixTypeId(TheMatrixTypeId) { + validate(); + assert(TheBB && "Invalid BB"); + } + // Incomplete constructor + SPIRVCooperativeMatrixLengthKHR() : SPIRVInstruction(OpCooperativeMatrixLengthKHR), MatrixTypeId(SPIRVID_INVALID) {} + + SPIRVCapVec getRequiredCapability() const override { return getVec(CapabilityCooperativeMatrixKHR); } + SPIRVType *getMatrixType() const { return Module->get(MatrixTypeId); } + +protected: + void decode(std::istream &I) override { getDecoder(I) >> Type >> Id >> MatrixTypeId; } + + void validate() const override { + SPIRVInstruction::validate(); + assert(getMatrixType()->isTypeCooperativeMatrixKHR() && "Invalid type"); + assert(Type->isTypeInt() && Type->getBitWidth() == 32 && !static_cast(Type)->isSigned() && + "Invalid result type"); + } + +private: + SPIRVId MatrixTypeId; +}; + +class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase { +public: + SPIRVCapVec getRequiredCapability() const override { return getVec(CapabilityCooperativeMatrixKHR); } + + SPIRVWord getMatrixASigned() const { + if (WordCount == 7) { + return Ops[3] & CooperativeMatrixOperandsMatrixASignedComponentsKHRMask; + } else { + return 0; + } + } + + SPIRVWord getMatrixBSigned() const { + if (WordCount == 7) { + return Ops[3] & CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask; + } else { + return 0; + } + } + + SPIRVWord getMatrixCSigned() const { + if (WordCount == 7) { + return Ops[3] & CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask; + } else { + return 0; + } + } + + SPIRVWord getMatrixResultSigned() const { + if (WordCount == 7) { + return Ops[3] & CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask; + } else { + return 0; + } + } + + SPIRVWord getMatrixSatAccumulation() const { + if (WordCount == 7) { + return Ops[3] & CooperativeMatrixOperandsSaturatingAccumulationKHRMask; + } else { + return 0; + } + } + +protected: + void validate() const override { + assert(getOpCode() == OpCooperativeMatrixMulAddKHR); + SPIRVType *AType = getValueType(Ops[0]); + SPIRVType *BType = getValueType(Ops[1]); + SPIRVType *CType = getValueType(Ops[2]); + assert(Type->getCooperativeMatrixKHRRows() == CType->getCooperativeMatrixKHRRows() && + Type->getCooperativeMatrixKHRColumns() == CType->getCooperativeMatrixKHRColumns() && + Type->getCooperativeMatrixKHRRows() == AType->getCooperativeMatrixKHRRows() && + Type->getCooperativeMatrixKHRColumns() == BType->getCooperativeMatrixKHRColumns() && + AType->getCooperativeMatrixKHRColumns() == BType->getCooperativeMatrixKHRRows() && "Inconsistent MxKxN"); + assert(Type->isTypeCooperativeMatrixKHR() && AType->isTypeCooperativeMatrixKHR() && + BType->isTypeCooperativeMatrixKHR() && CType->isTypeCooperativeMatrixKHR() && "Invalid A/B/C/D type"); + assert(Type->getCooperativeMatrixKHRScope() == AType->getCooperativeMatrixKHRScope() && + Type->getCooperativeMatrixKHRScope() == BType->getCooperativeMatrixKHRScope() && + Type->getCooperativeMatrixKHRScope() == CType->getCooperativeMatrixKHRScope() && "Inconsistent scope"); + (void)AType; + (void)BType; + (void)CType; + } +}; +#define _SPIRV_OP(x, ...) typedef SPIRVInstTemplate SPIRV##x; +_SPIRV_OP(CooperativeMatrixMulAddKHR, true, 6, true, 3) +#undef _SPIRV_OP + } // namespace SPIRV #endif diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index bada1bc1c8..90836b1f17 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -499,6 +499,7 @@ template <> inline void SPIRVMap::init() { add(CapabilityWorkgroupMemoryExplicitLayoutKHR, "WorkgroupMemoryExplicitLayoutKHR"); add(CapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR, "WorkgroupMemoryExplicitLayout8BitAccessKHR"); add(CapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR, "WorkgroupMemoryExplicitLayout16BitAccessKHR"); + add(CapabilityCooperativeMatrixKHR, "CooperativeMatrixKHR"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h index 0067944e03..9172c7ad9a 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h @@ -326,6 +326,11 @@ _SPIRV_OP(TraceNV, 5337) _SPIRV_OP(RayQueryGetIntersectionTriangleVertexPositionsKHR, 5340) _SPIRV_OP(TypeAccelerationStructureKHR, 5341) _SPIRV_OP(ExecuteCallableKHR, 5344) +_SPIRV_OP(TypeCooperativeMatrixKHR, 4456) +_SPIRV_OP(CooperativeMatrixLoadKHR, 4457) +_SPIRV_OP(CooperativeMatrixStoreKHR, 4458) +_SPIRV_OP(CooperativeMatrixMulAddKHR, 4459) +_SPIRV_OP(CooperativeMatrixLengthKHR, 4460) _SPIRV_OP(DemoteToHelperInvocationEXT, 5380) _SPIRV_OP(IsHelperInvocationEXT, 5381) _SPIRV_OP(SubgroupShuffleINTEL, 5571) diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.cpp b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.cpp index 8e2a70d5e2..757c754f58 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.cpp +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.cpp @@ -191,6 +191,31 @@ SPIRVWord SPIRVType::getCompositeElementCount() const { return 1; } +SPIRVType *SPIRVType::getCooperativeMatrixKHRComponentType() const { + assert(OpCode == OpTypeCooperativeMatrixKHR && "Not cooperative matrix type"); + return static_cast(this)->getComponentType(); +} + +uint32_t SPIRVType::getCooperativeMatrixKHRScope() const { + assert(OpCode == OpTypeCooperativeMatrixKHR && "Not cooperative matrix type"); + return static_cast(this)->getScope()->getZExtIntValue(); +} + +uint32_t SPIRVType::getCooperativeMatrixKHRRows() const { + assert(OpCode == OpTypeCooperativeMatrixKHR && "Not cooperative matrix type"); + return static_cast(this)->getRows()->getZExtIntValue(); +} + +uint32_t SPIRVType::getCooperativeMatrixKHRColumns() const { + assert(OpCode == OpTypeCooperativeMatrixKHR && "Not cooperative matrix type"); + return static_cast(this)->getColumns()->getZExtIntValue(); +} + +uint32_t SPIRVType::getCooperativeMatrixKHRUse() const { + assert(OpCode == OpTypeCooperativeMatrixKHR && "Not cooperative matrix type"); + return static_cast(this)->getUse()->getZExtIntValue(); +} + bool SPIRVType::isTypeVoid() const { return OpCode == OpTypeVoid; } @@ -263,6 +288,10 @@ bool SPIRVType::isTypeRayQueryKHR() const { return OpCode == OpTypeRayQueryKHR; } +bool SPIRVType::isTypeCooperativeMatrixKHR() const { + return OpCode == OpTypeCooperativeMatrixKHR; +} + bool SPIRVType::isTypeVectorBool() const { return isTypeVector() && getVectorComponentType()->isTypeBool(); } @@ -329,4 +358,41 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) { Decoder >> Id >> SC; } +void SPIRVTypeCooperativeMatrixKHR::validate() const { + SPIRVEntry::validate(); + CompType->validate(); + assert(CompType->isTypeInt() || CompType->isTypeFloat()); + assert(isa(getValue(Rows)) || isa(getValue(Rows))); + assert(getValue(Rows)->getType()->isTypeInt()); + assert(isa(getValue(Columns)) || isa(getValue(Columns))); + assert(getValue(Columns)->getType()->isTypeInt()); + assert(isa(getValue(Use)) || isa(getValue(Use))); + assert(getValue(Use)->getType()->isTypeInt()); + // CompIntp is still under dicussion: + // assert(isa(getValue(CompIntp)) || isa(getValue(CompIntp))); + // assert(getValue(CompIntp)->getType()->isTypeInt()); +} + +SPIRVConstant *SPIRVTypeCooperativeMatrixKHR::getScope() const { + return get(Scope); +} + +SPIRVConstant *SPIRVTypeCooperativeMatrixKHR::getRows() const { + return get(Rows); +} + +SPIRVConstant *SPIRVTypeCooperativeMatrixKHR::getColumns() const { + return get(Columns); +} + +SPIRVConstant *SPIRVTypeCooperativeMatrixKHR::getUse() const { + return get(Use); +} + +SPIRVConstant *SPIRVTypeCooperativeMatrixKHR::getComIntp() const { + return get(CompIntp); +} + +_SPIRV_IMP_DECODE6(SPIRVTypeCooperativeMatrixKHR, Id, CompType, Scope, Rows, Columns, Use) + } // namespace SPIRV diff --git a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.h b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.h index 729f2e46b4..b5709c1007 100644 --- a/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.h +++ b/llpc/translator/lib/SPIRV/libSPIRV/SPIRVType.h @@ -79,6 +79,11 @@ class SPIRVType : public SPIRVEntry { SPIRVType *getMatrixColumnType() const; SPIRVType *getCompositeElementType(size_t) const; SPIRVWord getCompositeElementCount() const; + SPIRVType *getCooperativeMatrixKHRComponentType() const; + uint32_t getCooperativeMatrixKHRScope() const; + uint32_t getCooperativeMatrixKHRRows() const; + uint32_t getCooperativeMatrixKHRColumns() const; + uint32_t getCooperativeMatrixKHRUse() const; bool isTypeVoid() const; bool isTypeArray() const; bool isTypeRuntimeArray() const; @@ -103,6 +108,7 @@ class SPIRVType : public SPIRVEntry { bool isTypeVectorOrScalarBool() const; bool isTypeAccelerationStructureKHR() const; bool isTypeRayQueryKHR() const; + bool isTypeCooperativeMatrixKHR() const; }; class SPIRVTypeVoid : public SPIRVType { @@ -638,6 +644,55 @@ class SPIRVTypeRayQueryKHR : public SPIRVType { _SPIRV_DEF_DECODE1(Id) }; +class SPIRVTypeCooperativeMatrixKHR : public SPIRVType { +public: + // Compile constructor + SPIRVTypeCooperativeMatrixKHR(SPIRVModule *M, SPIRVId TheId, SPIRVType *TheCompType, SPIRVId TheScope, + SPIRVId TheRows, SPIRVId TheColumns, SPIRVId TheUse, SPIRVId TheCompIntp) + : SPIRVType(M, 7, OpTypeCooperativeMatrixKHR, TheId), CompType(TheCompType), Scope(TheScope), Rows(TheRows), + Columns(TheColumns), Use(TheUse), CompIntp(TheCompIntp) { + validate(); + } + // Incomplete constructor + SPIRVTypeCooperativeMatrixKHR() + : SPIRVType(OpTypeCooperativeMatrixKHR), CompType(nullptr), Scope(ScopeSubgroup), Rows(0), Columns(0), Use(0), + CompIntp(0) {} + + SPIRVType *getComponentType() const { return CompType; } + SPIRVConstant *getScope() const; + SPIRVConstant *getRows() const; + SPIRVConstant *getColumns() const; + SPIRVConstant *getUse() const; + SPIRVConstant *getComIntp() const; + SPIRVCapVec getRequiredCapability() const override { + SPIRVCapVec V(getComponentType()->getRequiredCapability()); + V.push_back(CapabilityCooperativeMatrixKHR); + return V; + } + std::vector getNonLiteralOperands() const override { + std::vector Operands(6); + Operands[0] = CompType; + Operands[1] = (SPIRVEntry *)getScope(); + Operands[2] = (SPIRVEntry *)getRows(); + Operands[3] = (SPIRVEntry *)getColumns(); + Operands[4] = (SPIRVEntry *)getUse(); + Operands[5] = (SPIRVEntry *)getComIntp(); + return Operands; + } + +protected: + _SPIRV_DCL_DECODE + void validate() const override; + +private: + SPIRVType *CompType; // Component Type + SPIRVId Scope; // The scope all invocations belonging to + SPIRVId Rows; // The matrix row number + SPIRVId Columns; // The matrix column number + SPIRVId Use; // The matrix use: A/B/C + SPIRVId CompIntp; // Specifies how Component Type is interpreted +}; + template bool isType(const T1 *Ty, unsigned Bits = 0) { bool Is = Ty->getOpCode() == T2::OC; if (!Is) diff --git a/shared/continuations/CMakeLists.txt b/shared/continuations/CMakeLists.txt index 74812c1ea2..6c97f64589 100644 --- a/shared/continuations/CMakeLists.txt +++ b/shared/continuations/CMakeLists.txt @@ -11,10 +11,13 @@ function(set_compiler_options PROJECT_NAME) endif() endfunction() +option(CONTINUATIONS_BUILD_TESTS "Build continuation tests") + add_llvm_library(LLVMContinuations lib/CleanupContinuations.cpp lib/ContinuationsDialect.cpp lib/ContinuationsUtil.cpp + lib/CpsStackLowering.cpp lib/DXILCont.cpp lib/DXILContIntrinsicPrepare.cpp lib/DXILContLgcRtOpConverter.cpp @@ -53,11 +56,13 @@ target_include_directories(LLVMContinuations PUBLIC $ ) -target_link_libraries(LLVMContinuations PUBLIC llvm_dialects) +llvm_map_components_to_libnames(extra_llvm_libs CompilerUtils) + +target_link_libraries(LLVMContinuations PUBLIC ${extra_llvm_libs}) set_compiler_options(LLVMContinuations) # TableGen for dialects -set(CONTINUATIONS_TABLEGEN_EXE llvm-dialects-tblgen) +set(CONTINUATIONS_TABLEGEN_EXE $) set(CONTINUATIONS_TABLEGEN_TARGET llvm-dialects-tblgen) macro(cont_tablegen DIALECTNAME FILE OUTPUT_FILENAME) @@ -82,7 +87,7 @@ target_compile_features(LLVMContinuations PUBLIC cxx_std_17) set_target_properties(LLVMContinuations PROPERTIES CXX_EXTENSIONS OFF) add_subdirectory(plugin) -add_subdirectory(test) -if(LLPC_BUILD_TESTS) +if(CONTINUATIONS_BUILD_TESTS) + add_subdirectory(test) add_subdirectory(unittests) endif() diff --git a/shared/continuations/include/continuations/Continuations.h b/shared/continuations/include/continuations/Continuations.h index c6059fc6c4..e61defe6c7 100644 --- a/shared/continuations/include/continuations/Continuations.h +++ b/shared/continuations/include/continuations/Continuations.h @@ -205,6 +205,12 @@ CallInst *replaceIntrinsicCall(IRBuilder<> &B, Type *SystemDataTy, Value *SystemData, DXILShaderKind Kind, CallInst *Call); +/// Transformations that run early on the driver/gpurt module. +/// +/// Replace intrinsics called by gpurt code that can be replaced early. +/// Returns whether something changed. +bool earlyDriverTransform(Module &M); + /// Buffered pointers use a fixed number of registers, and fall back to an /// allocation if the registers to not suffice to contain the content. Given a /// number NumI32s of 4-byte values and the number of reserved registers, return @@ -298,6 +304,7 @@ class LegacyCleanupContinuationsPass Function *RestoreContState; Function *RegisterBufferSetPointerBarrier; Function *Continue; + Function *WaitContinue; Function *Complete; GlobalVariable *ContState; MapVector ToProcess; diff --git a/shared/continuations/include/continuations/ContinuationsUtil.h b/shared/continuations/include/continuations/ContinuationsUtil.h index 42bfb233f7..2dc5396423 100644 --- a/shared/continuations/include/continuations/ContinuationsUtil.h +++ b/shared/continuations/include/continuations/ContinuationsUtil.h @@ -34,7 +34,7 @@ #include "lgc/LgcCpsDialect.h" #include "lgc/LgcRtDialect.h" -#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpMap.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" @@ -47,6 +47,7 @@ #include #include #include +#include namespace DialectUtils { @@ -54,21 +55,6 @@ llvm::StringRef getLgcRtDialectOpName(llvm::StringRef FullName); bool isLgcRtOp(const llvm::Function *F); -template bool isDialectOpDeclaration(llvm::Function &F) { - static llvm_dialects::OpDescription Desc = - llvm_dialects::OpDescription::get(); - return Desc.matchDeclaration(F); -} - -template -bool isAnyDialectOpDeclaration(llvm::Function &F) { - return (isDialectOpDeclaration(F) || ...); -} - -template -bool isNoneOfDialectOpDeclaration(llvm::Function &F) { - return (!isDialectOpDeclaration(F) && ...); -} } // namespace DialectUtils namespace llvm { @@ -120,7 +106,7 @@ struct GpuRtIntrinsicEntry { bool AccessesHitData = false; }; -extern StringMap LgcRtGpuRtMap; +extern const OpMap LgcRtGpuRtMap; // This must match DXIL::ShaderKind from DxilConstants.h, and also // DXILShaderKind in a matching definition in GPURT, because it is used @@ -268,8 +254,18 @@ class DXILContHelper { // limit, we spill to the continuation stack. static constexpr const char *MDMaxPayloadRegisterCountName = "continuation.maxPayloadRegisterCount"; + // The address space used to store the continuations stack. + // The possible values for this metadata are the values of ContStackAddrspace. static constexpr const char *MDStackAddrspaceName = "continuation.stackAddrspace"; + // The raytracing ip level that is available on the target architecture. + // This is exposed to gpurt code via the GetRtip intrinsic. + static constexpr const char *MDRtipName = "continuation.rtip"; + // Flags set for continuations. + // This is exposed to gpurt code via the ContinuationsGetFlags intrinsic. + static constexpr const char *MDFlagsName = "continuation.flags"; + // Marks an await as a waiting one with a wait mask. + static constexpr const char *MDIsWaitAwaitName = "continuation.wait.await"; // Function-scope metadata for payload and hit attribute size limits, // referring to the app-defined structs only. @@ -324,6 +320,8 @@ class DXILContHelper { static constexpr const char *GlobalPayloadName = "PAYLOAD"; static constexpr const char *GlobalContStateName = "CONTINUATION_STATE"; static constexpr const char *GlobalRegistersName = "REGISTERS"; + static constexpr ContStackAddrspace DefaultStackAddrspace = + ContStackAddrspace::Scratch; static void RegisterPasses(llvm::PassBuilder &PB, bool NeedDialectContext); @@ -488,6 +486,32 @@ class DXILContHelper { static_cast(StackAddrspace))); } + static std::optional tryGetRtip(const Module &M) { + auto *MD = M.getNamedMetadata(MDRtipName); + if (!MD) + return {}; + return extractZExtI32Constant(MD->getOperand(0)); + }; + + static void setRtip(Module &M, uint32_t RtipLevel) { + auto *MD = M.getOrInsertNamedMetadata(MDRtipName); + MD->clearOperands(); + MD->addOperand(getI32MDConstant(M.getContext(), RtipLevel)); + } + + static std::optional tryGetFlags(const Module &M) { + auto *MD = M.getNamedMetadata(MDFlagsName); + if (!MD) + return {}; + return extractZExtI32Constant(MD->getOperand(0)); + }; + + static void setFlags(Module &M, uint32_t Flags) { + auto *MD = M.getOrInsertNamedMetadata(MDFlagsName); + MD->clearOperands(); + MD->addOperand(getI32MDConstant(M.getContext(), Flags)); + } + static void setContinuationStateByteCount(Function &F, uint32_t ByteCount) { F.setMetadata(MDStateName, getI32MDConstant(F.getContext(), ByteCount)); } @@ -523,6 +547,21 @@ class DXILContHelper { return Mod.getNamedMetadata(MDLgcCpsModule) != nullptr; } + // Specifies that an awaited call should wait on a wait mask. + static void setIsWaitAwaitCall(CallInst &CI) { + CI.setMetadata(DXILContHelper::MDIsWaitAwaitName, + MDTuple::get(CI.getContext(), {})); + } + + // Queries whether an awaited call should wait on a wait mask. + static bool isWaitAwaitCall(const CallInst &CI) { + return CI.getMetadata(MDIsWaitAwaitName) != nullptr; + } + + static void removeIsWaitAwaitMetadata(CallInst &CI) { + CI.setMetadata(DXILContHelper::MDIsWaitAwaitName, nullptr); + } + static DXILShaderKind shaderStageToDxilShaderKind(lgc::rt::RayTracingShaderStage Stage) { switch (Stage) { @@ -568,18 +607,18 @@ class DXILContHelper { /// Free-standing helpers. -// A little helper function that allows to apply a callback on the users (calls) -// of a function. -void forEachCall(Function &F, const std::function &Callback); - -// A little helper function that allows to apply a callback on the users (calls) -// of a set of functions given by iterating over a module. -void forEachCall(Module &M, const std::function &Callback); - -// A little helper function that allows to apply a callback on the users (calls) -// of a set of functions. -void forEachCall(ArrayRef Funcs, - const std::function &Callback); +// Helper to visit all calls of a function. +// Expected type for Callback: +// void(CallInst &) +template +void forEachCall(Function &F, CallbackTy Callback) { + static_assert(std::is_invocable_v); + for (auto &Use : make_early_inc_range(F.uses())) { + if (auto *CInst = dyn_cast(Use.getUser())) + if (CInst->isCallee(&Use)) + Callback(*CInst); + } +} // Move all basic blocks of OldFunc to NewFunc. void moveFunctionBody(Function &OldFunc, Function &NewFunc); diff --git a/shared/continuations/include/continuations/CpsStackLowering.h b/shared/continuations/include/continuations/CpsStackLowering.h new file mode 100644 index 0000000000..a0f65207ee --- /dev/null +++ b/shared/continuations/include/continuations/CpsStackLowering.h @@ -0,0 +1,94 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2020-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + *deal in the Software without restriction, including without limitation the + *rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + *sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + *all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + *FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + *IN THE SOFTWARE. + * + **********************************************************************************************************************/ +/** + *********************************************************************************************************************** + * @file CpsStackLowering.h + * @brief Contains declaration of class CpsStackLowering. + *********************************************************************************************************************** + */ +#pragma once + +#include "compilerutils/TypeLowering.h" +#include "lgc/LgcCpsDialect.h" +#include "llvm/ADT/SmallVector.h" + +namespace llvm { +class LLVMContext; +class Function; +class GetElementPtrInst; +class PtrToIntInst; +class IntToPtrInst; +class LoadInst; +class StoreInst; +class Module; +class Value; +class Type; +class DataLayout; +} // namespace llvm + +using namespace lgc::cps; + +constexpr unsigned ContinuationStackAlignment = 4; + +class CpsStackLowering { +public: + CpsStackLowering(llvm::LLVMContext &Context, + unsigned LoweredCpsStackAddrSpace) + : TypeLower(Context), LoweredCpsStackAddrSpace{LoweredCpsStackAddrSpace} { + } + void lowerCpsStackOps(llvm::Function &, llvm::Value *); + // Get continuation stack size (in bytes). + unsigned getStackSizeInBytes() { return StackSizeInBytes; } + + inline unsigned getLoweredCpsStackAddrSpace() const { + return LoweredCpsStackAddrSpace; + } + + inline unsigned + getLoweredCpsStackPointerSize(const llvm::DataLayout &Layout) { + return Layout.getPointerSize(LoweredCpsStackAddrSpace); + } + + TypeLowering TypeLower; + +private: + llvm::SmallVector convertCpsStackPointer(TypeLowering &, + llvm::Type *); + void visitCpsAlloc(AllocOp &); + void visitCpsFree(FreeOp &); + void visitCpsPeek(PeekOp &); + void visitSetVsp(SetVspOp &); + void visitGetVsp(GetVspOp &); + void visitGetElementPtr(llvm::GetElementPtrInst &); + void visitPtrToIntInst(llvm::PtrToIntInst &); + void visitIntToPtrInst(llvm::IntToPtrInst &); + void visitLoad(llvm::LoadInst &); + void visitStore(llvm::StoreInst &); + + llvm::Module *Mod; + llvm::Value *CpsStackAlloca; + unsigned LoweredCpsStackAddrSpace; + unsigned StackSizeInBytes = 0; +}; diff --git a/shared/continuations/include/continuations/LowerRaytracingPipeline.h b/shared/continuations/include/continuations/LowerRaytracingPipeline.h index d22ece5fce..4ea976fbd5 100644 --- a/shared/continuations/include/continuations/LowerRaytracingPipeline.h +++ b/shared/continuations/include/continuations/LowerRaytracingPipeline.h @@ -74,15 +74,7 @@ class ModuleMetadataState final { return MinPayloadRegisterCount; } - ContStackAddrspace getContStackAddrspace() const { return StackAddrspace; }; - - bool isGlobalAddressSpace() const { - return StackAddrspace == ContStackAddrspace::Global; - } - - [[maybe_unused]] bool isScratchAddressSpace() const { - return StackAddrspace == ContStackAddrspace::Scratch; - } + ContStackAddrspace getContStackAddrspace() const { return StackAddrspace; } void updateModuleMetadata() const; @@ -93,13 +85,11 @@ class ModuleMetadataState final { static constexpr uint32_t DefaultPayloadRegisterCount = 30; /// Maximum allowed number of registers to be used for the payload. uint32_t MaxPayloadRegisterCount = 0; - //// Minimum required number of payload registers. + /// Minimum required number of payload registers. uint32_t MinPayloadRegisterCount = 0; /// The address space used for the continuations stack. /// Either stack or global memory. - ContStackAddrspace StackAddrspace = DefaultStackAddrspace; - static constexpr ContStackAddrspace DefaultStackAddrspace = - ContStackAddrspace::Scratch; + ContStackAddrspace StackAddrspace = DXILContHelper::DefaultStackAddrspace; }; class CpsMutator final { @@ -214,8 +204,6 @@ class LowerRaytracingPipelinePassImpl final { void replaceShaderIndexCall(IRBuilder<> &B, FunctionData &Data, CallInst *Call); - void handleContinuationStackIsGlobal(Function &Func); - void handleGetFuncAddr(Function &Func); void handleAmdInternalFunc(Function &Func); @@ -224,8 +212,6 @@ class LowerRaytracingPipelinePassImpl final { void collectDriverFunctions(); - void handleGetUninitialized(Function &Func); - // Copy the payload content between global payload and local payload. // Excludes the stack pointer or hit attributes which may also reside in // payload storage. If Stage is not set, all fields in SerializationInfo are diff --git a/shared/continuations/lib/ContinuationsUtil.cpp b/shared/continuations/lib/ContinuationsUtil.cpp index 4b32acb50e..419b710a53 100644 --- a/shared/continuations/lib/ContinuationsUtil.cpp +++ b/shared/continuations/lib/ContinuationsUtil.cpp @@ -31,33 +31,41 @@ //===----------------------------------------------------------------------===// #include "continuations/ContinuationsUtil.h" +#include "lgc/LgcRtDialect.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -llvm::StringMap llvm::LgcRtGpuRtMap = { - {"instance.id", {"InstanceID", true}}, - {"instance.index", {"InstanceIndex", true}}, - {"hit.kind", {"HitKind", true}}, - {"ray.flags", {"RayFlags", false}}, - {"dispatch.rays.index", {"DispatchRaysIndex3", false}}, - {"dispatch.rays.dimensions", {"DispatchRaysDimensions3", false}}, - {"world.ray.origin", {"WorldRayOrigin3", false}}, - {"world.ray.direction", {"WorldRayDirection3", false}}, - {"object.ray.origin", {"ObjectRayOrigin3", true}}, - {"object.ray.direction", {"ObjectRayDirection3", true}}, - {"object.to.world", {"ObjectToWorld4x3", true}}, - {"world.to.object", {"WorldToObject4x3", true}}, - {"ray.tmin", {"RayTMin", false}}, - {"ray.tcurrent", {"RayTCurrent", true}}, - {"ignore.hit", {"IgnoreHit", false}}, - {"accept.hit.and.end.search", {"AcceptHitAndEndSearch", false}}, - {"trace.ray", {"TraceRay", false}}, - {"report.hit", {"ReportHit", false}}, - {"call.callable.shader", {"CallShader", false}}, - {"primitive.index", {"PrimitiveIndex", true}}, - {"geometry.index", {"GeometryIndex", true}}, -}; +#define GPURTMAP_ENTRY(Op, GpurtName, AccessesHitData) \ + { \ + OpDescription::get(), { GpurtName, AccessesHitData } \ + } + +const OpMap llvm::LgcRtGpuRtMap = {{ + GPURTMAP_ENTRY(InstanceIdOp, "InstanceID", true), + GPURTMAP_ENTRY(InstanceIndexOp, "InstanceIndex", true), + GPURTMAP_ENTRY(HitKindOp, "HitKind", true), + GPURTMAP_ENTRY(RayFlagsOp, "RayFlags", false), + GPURTMAP_ENTRY(DispatchRaysIndexOp, "DispatchRaysIndex3", false), + GPURTMAP_ENTRY(DispatchRaysDimensionsOp, "DispatchRaysDimensions3", false), + GPURTMAP_ENTRY(WorldRayOriginOp, "WorldRayOrigin3", false), + GPURTMAP_ENTRY(WorldRayDirectionOp, "WorldRayDirection3", false), + GPURTMAP_ENTRY(ObjectRayOriginOp, "ObjectRayOrigin3", true), + GPURTMAP_ENTRY(ObjectRayDirectionOp, "ObjectRayDirection3", true), + GPURTMAP_ENTRY(ObjectToWorldOp, "ObjectToWorld4x3", true), + GPURTMAP_ENTRY(WorldToObjectOp, "WorldToObject4x3", true), + GPURTMAP_ENTRY(RayTminOp, "RayTMin", false), + GPURTMAP_ENTRY(RayTcurrentOp, "RayTCurrent", true), + GPURTMAP_ENTRY(IgnoreHitOp, "IgnoreHit", false), + GPURTMAP_ENTRY(AcceptHitAndEndSearchOp, "AcceptHitAndEndSearch", false), + GPURTMAP_ENTRY(TraceRayOp, "TraceRay", false), + GPURTMAP_ENTRY(ReportHitOp, "ReportHit", false), + GPURTMAP_ENTRY(CallCallableShaderOp, "CallShader", false), + GPURTMAP_ENTRY(PrimitiveIndexOp, "PrimitiveIndex", true), + GPURTMAP_ENTRY(GeometryIndexOp, "GeometryIndex", true), +}}; + +#undef GPURTMAP_ENTRY llvm::StringRef DialectUtils::getLgcRtDialectOpName(llvm::StringRef FullName) { return FullName.substr(std::strlen("lgc.rt.")); @@ -67,31 +75,6 @@ bool DialectUtils::isLgcRtOp(const llvm::Function *F) { return F && F->getName().starts_with("lgc.rt"); } -// A small wrapper around that allows to apply a callback on the users (calls) -// of a function -void llvm::forEachCall(Function &F, - const std::function &Callback) { - for (auto &Use : make_early_inc_range(F.uses())) { - if (auto *CInst = dyn_cast(Use.getUser())) - if (CInst->isCallee(&Use)) - Callback(*CInst); - } -} - -void llvm::forEachCall(Module &M, - const std::function &Callback) { - for (auto &Func : M) { - forEachCall(Func, Callback); - } -} - -void llvm::forEachCall(ArrayRef Funcs, - const std::function &Callback) { - for (auto *Func : Funcs) { - forEachCall(*Func, Callback); - } -} - void llvm::moveFunctionBody(Function &OldFunc, Function &NewFunc) { while (!OldFunc.empty()) { BasicBlock *BB = &OldFunc.front(); @@ -105,13 +88,11 @@ llvm::findIntrImplEntryByIntrinsicCall(CallInst *Call) { if (!DialectUtils::isLgcRtOp(Call->getCalledFunction())) return std::nullopt; - auto Name = Call->getCalledFunction()->getName(); - auto ImplEntry = - LgcRtGpuRtMap.find(DialectUtils::getLgcRtDialectOpName(Name)); + auto ImplEntry = LgcRtGpuRtMap.find(*Call); if (ImplEntry == LgcRtGpuRtMap.end()) report_fatal_error("Unhandled lgc.rt op!"); - return ImplEntry->second; + return *ImplEntry.val(); } bool llvm::removeUnusedFunctionDecls(Module *Mod, bool OnlyIntrinsics) { diff --git a/shared/continuations/lib/CpsStackLowering.cpp b/shared/continuations/lib/CpsStackLowering.cpp new file mode 100644 index 0000000000..615815492a --- /dev/null +++ b/shared/continuations/lib/CpsStackLowering.cpp @@ -0,0 +1,262 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2020-2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + *deal in the Software without restriction, including without limitation the + *rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + *sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + *all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + *FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + *IN THE SOFTWARE. + * + **********************************************************************************************************************/ + +#include "continuations/CpsStackLowering.h" +#include "llvm-dialects/Dialect/Visitor.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include + +using namespace llvm; +using namespace lgc::cps; + +LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(CpsStackLowering, TypeLower) + +// ===================================================================================================================== +// Type lowering rule that lowers cps stack pointer type to corresponding +// backend pointer type. +// +// @param typeLowering : the calling TypeLowering object +// @param type : the type to be converted +SmallVector +CpsStackLowering::convertCpsStackPointer(TypeLowering &TypeLower, Type *Ty) { + SmallVector Types; + + if (auto *PtrTy = dyn_cast(Ty)) { + if (PtrTy->getAddressSpace() == lgc::cps::stackAddrSpace) + Types.push_back( + PointerType::get(Ty->getContext(), LoweredCpsStackAddrSpace)); + } + + return Types; +} + +// ===================================================================================================================== +// Lower continuation stack operations in the function +// +// @param Function : the function to be processed +// @param CpsStorage : the alloca used for the holding the latest continuation +// stack pointer +void CpsStackLowering::lowerCpsStackOps(Function &Function, Value *CpsStorage) { + Mod = Function.getParent(); + StackSizeInBytes = 0; + CpsStackAlloca = CpsStorage; + TypeLower.addRule(std::bind(&CpsStackLowering::convertCpsStackPointer, this, + std::placeholders::_1, std::placeholders::_2)); + auto *NewFunc = &Function; + if (lgc::cps::isCpsFunction(Function)) + NewFunc = TypeLower.lowerFunctionArguments(Function); + + static const auto Visitor = llvm_dialects::VisitorBuilder() + .nest(&TypeLowering::registerVisitors) + .add(&CpsStackLowering::visitCpsAlloc) + .add(&CpsStackLowering::visitCpsFree) + .add(&CpsStackLowering::visitCpsPeek) + .add(&CpsStackLowering::visitSetVsp) + .add(&CpsStackLowering::visitGetVsp) + .add(&CpsStackLowering::visitGetElementPtr) + .add(&CpsStackLowering::visitPtrToIntInst) + .add(&CpsStackLowering::visitIntToPtrInst) + .add(&CpsStackLowering::visitLoad) + .add(&CpsStackLowering::visitStore) + .build(); + Visitor.visit(*this, *NewFunc); + TypeLower.finishPhis(); + TypeLower.finishCleanup(); +} + +// ===================================================================================================================== +// Lower getelementptr instruction +// +// @param function : the instruction +void CpsStackLowering::visitGetElementPtr(GetElementPtrInst &GEP) { + if (GEP.getAddressSpace() != lgc::cps::stackAddrSpace) + return; + + IRBuilder<> Builder(&GEP); + + SmallVector Indices(GEP.idx_begin(), GEP.idx_end()); + + Value *NewGEP = nullptr; + auto Values = TypeLower.getValue(GEP.getPointerOperand()); + auto *GEPVal = Values[0]; + auto *GEPTy = GEP.getSourceElementType(); + + if (GEP.isInBounds()) + NewGEP = Builder.CreateInBoundsGEP(GEPTy, GEPVal, Indices); + else + NewGEP = Builder.CreateGEP(GEPTy, GEPVal, Indices); + + cast(NewGEP)->copyMetadata(GEP); + + TypeLower.replaceInstruction(&GEP, {NewGEP}); +} + +// ===================================================================================================================== +// Lower load instruction +// +// @param function : the instruction +void CpsStackLowering::visitLoad(LoadInst &Load) { + if (Load.getPointerAddressSpace() != lgc::cps::stackAddrSpace) + return; + + auto Values = TypeLower.getValue(Load.getPointerOperand()); + Load.replaceUsesOfWith(Load.getPointerOperand(), Values[0]); +} + +// ===================================================================================================================== +// Lower store instruction +// +// @param function : the instruction +void CpsStackLowering::visitStore(llvm::StoreInst &Store) { + if (Store.getPointerAddressSpace() != lgc::cps::stackAddrSpace) + return; + + auto Values = TypeLower.getValue(Store.getPointerOperand()); + Store.replaceUsesOfWith(Store.getPointerOperand(), Values[0]); +} + +// ===================================================================================================================== +// Lower ptrtoint instruction +// +// @param function : the instruction +void CpsStackLowering::visitPtrToIntInst(llvm::PtrToIntInst &Ptr2Int) { + if (Ptr2Int.getPointerAddressSpace() != lgc::cps::stackAddrSpace) + return; + + auto Values = TypeLower.getValue(Ptr2Int.getOperand(0)); + Ptr2Int.replaceUsesOfWith(Ptr2Int.getOperand(0), Values[0]); +} + +// ===================================================================================================================== +// Lower inttoptr instruction +// +// @param function : the instruction +void CpsStackLowering::visitIntToPtrInst(llvm::IntToPtrInst &Int2Ptr) { + if (Int2Ptr.getAddressSpace() != lgc::cps::stackAddrSpace) + return; + + IRBuilder<> Builder(&Int2Ptr); + auto *NewPtr = Builder.CreateIntToPtr( + Int2Ptr.getOperand(0), + PointerType::get(Builder.getContext(), LoweredCpsStackAddrSpace)); + TypeLower.replaceInstruction(&Int2Ptr, NewPtr); +} + +// ===================================================================================================================== +// Lower lgc.cps.alloc instruction +// +// @param function : the instruction +void CpsStackLowering::visitCpsAlloc(lgc::cps::AllocOp &AllocOp) { + IRBuilder<> Builder(&AllocOp); + + Value *Size = AllocOp.getSize(); + const DataLayout &Layout = Mod->getDataLayout(); + Value *VSP = Builder.CreateAlignedLoad( + Builder.getPtrTy(LoweredCpsStackAddrSpace), CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + unsigned AlignedSize = alignTo(cast(Size)->getZExtValue(), + ContinuationStackAlignment); + StackSizeInBytes += AlignedSize; + + // update stack pointer + Value *Ptr = + Builder.CreateConstGEP1_32(Builder.getInt8Ty(), VSP, AlignedSize); + Builder.CreateAlignedStore(Ptr, CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + + TypeLower.replaceInstruction(&AllocOp, {VSP}); +} + +// ===================================================================================================================== +// Lower lgc.cps.free instruction +// +// @param function : the instruction +void CpsStackLowering::visitCpsFree(lgc::cps::FreeOp &FreeOp) { + IRBuilder<> Builder(&FreeOp); + const DataLayout &Layout = Mod->getDataLayout(); + + Value *VSP = Builder.CreateAlignedLoad( + Builder.getPtrTy(LoweredCpsStackAddrSpace), CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + Value *Size = FreeOp.getSize(); + unsigned AlignedSize = alignTo(cast(Size)->getZExtValue(), + ContinuationStackAlignment); + Value *Ptr = + Builder.CreateConstGEP1_32(Builder.getInt8Ty(), VSP, -AlignedSize); + // Assuming continuation stack grows upward. + Builder.CreateAlignedStore(Ptr, CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + TypeLower.replaceInstruction(&FreeOp, {}); +} + +// ===================================================================================================================== +// Lower lgc.cps.peek instruction +// +// @param function : the instruction +void CpsStackLowering::visitCpsPeek(lgc::cps::PeekOp &PeekOp) { + IRBuilder<> Builder(&PeekOp); + const DataLayout &Layout = Mod->getDataLayout(); + + auto *Ptr = Builder.CreateAlignedLoad( + Builder.getPtrTy(LoweredCpsStackAddrSpace), CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + auto *Size = PeekOp.getSize(); + unsigned ImmSize = cast(Size)->getZExtValue(); + ImmSize = alignTo(ImmSize, ContinuationStackAlignment); + // Assuming continuation stack grows upward. + auto *Result = + Builder.CreateGEP(Builder.getInt8Ty(), Ptr, {Builder.getInt32(-ImmSize)}); + TypeLower.replaceInstruction(&PeekOp, {Result}); +} + +// ===================================================================================================================== +// Lower lgc.cps.set.VSP instruction +// +// @param function : the instruction +void CpsStackLowering::visitSetVsp(lgc::cps::SetVspOp &SetVsp) { + IRBuilder<> B(&SetVsp); + const DataLayout &Layout = Mod->getDataLayout(); + + auto *Ptr = SetVsp.getPtr(); + auto Converted = TypeLower.getValue(Ptr); + B.CreateAlignedStore(Converted[0], CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + TypeLower.replaceInstruction(&SetVsp, {}); +} + +// ===================================================================================================================== +// Lower lgc.cps.get.VSP instruction +// +// @param function : the instruction +void CpsStackLowering::visitGetVsp(lgc::cps::GetVspOp &GetVsp) { + IRBuilder<> B(&GetVsp); + const DataLayout &Layout = Mod->getDataLayout(); + + auto *Ptr = + B.CreateAlignedLoad(B.getPtrTy(LoweredCpsStackAddrSpace), CpsStackAlloca, + Align(getLoweredCpsStackPointerSize(Layout))); + TypeLower.replaceInstruction(&GetVsp, {Ptr}); +} diff --git a/shared/continuations/lib/DXILCont.cpp b/shared/continuations/lib/DXILCont.cpp index b04eaa3c5b..690f5c4196 100644 --- a/shared/continuations/lib/DXILCont.cpp +++ b/shared/continuations/lib/DXILCont.cpp @@ -649,6 +649,136 @@ CallInst *llvm::replaceIntrinsicCall(IRBuilder<> &B, Type *SystemDataTy, return NewCall; } +/// Transform enqueue intrinsics to continuation intrinsics +static void replaceEnqueueIntrinsic(Function &F, Function *NewFunc) { + for (auto &Use : make_early_inc_range(F.uses())) { + if (auto *CInst = dyn_cast(Use.getUser())) { + if (CInst->isCallee(&Use)) { + IRBuilder<> B(CInst); + SmallVector Args(CInst->args()); + bool IsEnqueue = F.getName().contains("Enqueue"); + // Add the current function as return address to the call. + // Used when Traversal calls AnyHit or Intersection. + if (IsEnqueue && F.getName().contains("EnqueueCall")) { + bool HasWaitMask = F.getName().contains("WaitEnqueue"); + auto *RetAddr = + B.CreatePtrToInt(CInst->getFunction(), B.getInt64Ty()); + Args.insert(Args.begin() + (HasWaitMask ? 3 : 2), RetAddr); + } + + B.CreateCall(NewFunc, Args); + CInst->eraseFromParent(); + } + } + } +} + +static void handleContinuationStackIsGlobal(Function &Func, + ContStackAddrspace StackAddrspace) { + assert(Func.arg_empty() + // bool + && Func.getFunctionType()->getReturnType()->isIntegerTy(1)); + + auto *IsGlobal = ConstantInt::getBool( + Func.getContext(), StackAddrspace == ContStackAddrspace::Global); + + llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { + CInst.replaceAllUsesWith(IsGlobal); + CInst.eraseFromParent(); + }); +} + +static void handleContinuationsGetFlags(Function &Func, uint32_t Flags) { + assert(Func.arg_empty() + // i32 + && Func.getFunctionType()->getReturnType()->isIntegerTy(32)); + + auto *FlagsConst = + ConstantInt::get(IntegerType::get(Func.getContext(), 32), Flags); + + llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { + CInst.replaceAllUsesWith(FlagsConst); + CInst.eraseFromParent(); + }); +} + +static void handleGetRtip(Function &Func, uint32_t RtipLevel) { + assert(Func.arg_empty() + // i32 + && Func.getFunctionType()->getReturnType()->isIntegerTy(32)); + + auto *RtipConst = + ConstantInt::get(IntegerType::get(Func.getContext(), 32), RtipLevel); + for (auto &Use : make_early_inc_range(Func.uses())) { + if (auto *CInst = dyn_cast(Use.getUser())) { + if (CInst->isCallee(&Use)) { + CInst->replaceAllUsesWith(RtipConst); + CInst->eraseFromParent(); + } + } + } +} + +static void handleGetUninitialized(Function &Func) { + auto *ArgTy = Func.getReturnType(); + auto *Poison = PoisonValue::get(ArgTy); + llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { + CInst.replaceAllUsesWith(Poison); + CInst.eraseFromParent(); + }); +} + +bool llvm::earlyDriverTransform(Module &M) { + // Import StackAddrspace from metadata if set, otherwise from default + auto StackAddrspaceMD = DXILContHelper::tryGetStackAddrspace(M); + auto StackAddrspace = + StackAddrspaceMD.value_or(DXILContHelper::DefaultStackAddrspace); + + // Import from metadata if set + auto RtipLevel = DXILContHelper::tryGetRtip(M); + auto Flags = DXILContHelper::tryGetFlags(M); + + bool Changed = false; + // Replace Enqueue and Complete intrinsics + for (auto &F : M) { + Function *Replacement = nullptr; + auto Name = F.getName(); + if (Name.contains("WaitEnqueue")) + Replacement = getContinuationWaitContinue(M); + else if (Name.contains("Enqueue")) + Replacement = getContinuationContinue(M); + else if (Name.contains("Complete")) + Replacement = getContinuationComplete(M); + + if (Replacement) { + Changed = true; + replaceEnqueueIntrinsic(F, Replacement); + } + + if (Name == "_AmdContinuationStackIsGlobal") { + Changed = true; + handleContinuationStackIsGlobal(F, StackAddrspace); + } else if (Name == "_AmdContinuationsGetFlags") { + Changed = true; + if (!Flags) + report_fatal_error("Tried to get continuation flags but it is not " + "available on the module"); + handleContinuationsGetFlags(F, *Flags); + } else if (Name == "_AmdGetRtip") { + Changed = true; + if (!RtipLevel) + report_fatal_error( + "Tried to get rtip level but it is not available on the module"); + handleGetRtip(F, *RtipLevel); + } else if (Name.startswith("_AmdGetUninitialized")) { + Changed = true; + handleGetUninitialized(F); + } + } + + return Changed; +} + uint64_t llvm::computeNeededStackSizeForRegisterBuffer(uint64_t NumI32s, uint64_t NumReservedRegisters) { diff --git a/shared/continuations/lib/DXILContIntrinsicPrepare.cpp b/shared/continuations/lib/DXILContIntrinsicPrepare.cpp index b8e13c4e95..ef8b6aae08 100644 --- a/shared/continuations/lib/DXILContIntrinsicPrepare.cpp +++ b/shared/continuations/lib/DXILContIntrinsicPrepare.cpp @@ -258,30 +258,6 @@ static Function *transformFunction(Function &F) { return NewFunc; } -/// Transform enqueue intrinsics to continuation intrinsics -static void replaceIntrinsic(Function &F, Function *NewFunc) { - for (auto &Use : make_early_inc_range(F.uses())) { - if (auto *CInst = dyn_cast(Use.getUser())) { - if (CInst->isCallee(&Use)) { - IRBuilder<> B(CInst); - SmallVector Args(CInst->args()); - bool IsEnqueue = F.getName().contains("Enqueue"); - // Add the current function as return address to the call. - // Used when Traversal calls AnyHit or Intersection. - if (IsEnqueue && F.getName().contains("EnqueueCall")) { - bool HasWaitMask = F.getName().contains("WaitEnqueue"); - auto *RetAddr = - B.CreatePtrToInt(CInst->getFunction(), B.getInt64Ty()); - Args.insert(Args.begin() + (HasWaitMask ? 3 : 2), RetAddr); - } - - B.CreateCall(NewFunc, Args); - CInst->eraseFromParent(); - } - } - } -} - static bool isGpuRtFuncName(StringRef Name) { for (const auto &Intr : LgcRtGpuRtMap) { if (Name.contains(Intr.second.Name)) @@ -308,6 +284,7 @@ static bool isUtilFunction(StringRef Name) { "GetFuncAddr", "GetLocalRootIndex", "GetResumePointAddr", + "GetRtip", "GetShaderKind", "GetTriangleHitAttributes", "GetUninitialized", @@ -343,22 +320,9 @@ llvm::PreservedAnalyses DXILContIntrinsicPreparePass::run( transformFunction(*F); } - // Recollect functions as they may have been replaced - for (auto &F : M.functions()) { - Function *Replacement = nullptr; - auto Name = F.getName(); - if (Name.contains("WaitEnqueue")) - Replacement = getContinuationWaitContinue(M); - else if (Name.contains("Enqueue")) - Replacement = getContinuationContinue(M); - else if (Name.contains("Complete")) - Replacement = getContinuationComplete(M); - - if (Replacement) - replaceIntrinsic(F, Replacement); - } - fixupDxilMetadata(M); + earlyDriverTransform(M); + return PreservedAnalyses::none(); } diff --git a/shared/continuations/lib/LegacyCleanupContinuations.cpp b/shared/continuations/lib/LegacyCleanupContinuations.cpp index b004efd169..775e0c7b79 100644 --- a/shared/continuations/lib/LegacyCleanupContinuations.cpp +++ b/shared/continuations/lib/LegacyCleanupContinuations.cpp @@ -524,14 +524,23 @@ void LegacyCleanupContinuationsPass::handleSingleContinue( auto *Csp = B.CreateLoad(CpsType, B.CreateCall(CspFun)); - // Replace this instruction with a call to continuation.continue + bool IsWait = DXILContHelper::isWaitAwaitCall(*Call); + Function *ContinueFunction = IsWait ? WaitContinue : Continue; + + // Replace this instruction with a call to continuation.[wait]continue SmallVector Args; Args.push_back(B.CreatePointerCast(Call->getCalledOperand(), I64)); + // The wait mask is the first argument after the function pointer + if (IsWait) + Args.push_back(*Call->arg_begin()); Args.push_back(Csp); Args.push_back(ReturnAddrInt); - Args.append(Call->arg_begin(), Call->arg_end()); - auto *ContinueCall = B.CreateCall(Continue, Args); + Args.append(Call->arg_begin() + (IsWait ? 1 : 0), Call->arg_end()); + auto *ContinueCall = B.CreateCall(ContinueFunction, Args); + // Copy metadata, except for the wait flag, which is no longer needed. ContinueCall->copyMetadata(*Call); + if (IsWait) + DXILContHelper::removeIsWaitAwaitMetadata(*ContinueCall); assert(DXILContHelper::tryGetOutgoingRegisterCount(ContinueCall) && "Missing registercount metadata!"); @@ -637,6 +646,7 @@ llvm::PreservedAnalyses LegacyCleanupContinuationsPass::run( SaveContState = getContinuationSaveContinuationState(Mod); RestoreContState = getContinuationRestoreContinuationState(Mod); Continue = getContinuationContinue(Mod); + WaitContinue = getContinuationWaitContinue(Mod); Complete = getContinuationComplete(Mod); // Add global diff --git a/shared/continuations/lib/LowerRaytracingPipeline.cpp b/shared/continuations/lib/LowerRaytracingPipeline.cpp index 591b46acda..137c6e4fbf 100644 --- a/shared/continuations/lib/LowerRaytracingPipeline.cpp +++ b/shared/continuations/lib/LowerRaytracingPipeline.cpp @@ -45,6 +45,7 @@ #include "continuations/ContinuationsUtil.h" #include "continuations/PayloadAccessQualifiers.h" #include "lgc/LgcRtDialect.h" +#include "llvm-dialects/Dialect/OpSet.h" #include "llvm-dialects/Dialect/Visitor.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" @@ -236,7 +237,8 @@ ModuleMetadataState::ModuleMetadataState(Module &Module) : Mod{Module} { // Import StackAddrspace from metadata if set, otherwise from default auto StackAddrspaceMD = DXILContHelper::tryGetStackAddrspace(Module); - StackAddrspace = StackAddrspaceMD.value_or(DefaultStackAddrspace); + StackAddrspace = + StackAddrspaceMD.value_or(DXILContHelper::DefaultStackAddrspace); } /// Write the previously derived information about max payload registers and @@ -331,8 +333,9 @@ bool llvm::isRematerializableLgcRtOp(CallInst &CInst, return false; // Always rematerialize - if (DialectUtils::isAnyDialectOpDeclaration(*Callee)) + static const OpSet RematerializableDialectOps = + OpSet::get(); + if (RematerializableDialectOps.contains(*Callee)) return true; // Rematerialize for Intersection that can only call ReportHit, which keeps @@ -341,11 +344,12 @@ bool llvm::isRematerializableLgcRtOp(CallInst &CInst, // information is lost from the system data struct. Also exclude rayTCurrent // because ReportHit calls can change that. if (!Kind || *Kind == DXILShaderKind::Intersection) { - if (DialectUtils::isAnyDialectOpDeclaration< - InstanceIdOp, InstanceIndexOp, GeometryIndexOp, - ObjectRayDirectionOp, ObjectRayOriginOp, ObjectToWorldOp, - PrimitiveIndexOp, RayFlagsOp, RayTminOp, WorldRayDirectionOp, - WorldRayOriginOp, WorldToObjectOp>(*Callee)) + static const OpSet RematerializableIntersectionDialectOps = + OpSet::get(); + if (RematerializableIntersectionDialectOps.contains(*Callee)) return true; } @@ -792,6 +796,11 @@ void LowerRaytracingPipelinePassImpl::replaceContinuationCall( MetadataState.getMaxPayloadRegisterCount())); DXILContHelper::setReturnedRegisterCount( Token, ContinuationStateRegisterCount + ReturnedRegisterCount.value()); + + // For WaitAwait, add metadata indicating that we wait. After coroutine + // passes, we then generate a waitContinue on the awaited function. + if (Call->getCalledFunction()->getName().startswith("_AmdWaitAwait")) + DXILContHelper::setIsWaitAwaitCall(*Token); } if (PassedPayload) { @@ -866,21 +875,6 @@ void LowerRaytracingPipelinePassImpl::replaceShaderIndexCall(IRBuilder<> &B, Call->eraseFromParent(); } -void LowerRaytracingPipelinePassImpl::handleContinuationStackIsGlobal( - Function &Func) { - assert(Func.arg_empty() - // bool - && Func.getFunctionType()->getReturnType()->isIntegerTy(1)); - - auto *IsGlobal = - ConstantInt::getBool(*Context, MetadataState.isGlobalAddressSpace()); - - llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { - CInst.replaceAllUsesWith(IsGlobal); - CInst.eraseFromParent(); - }); -} - void LowerRaytracingPipelinePassImpl::handleGetFuncAddr(Function &Func) { assert(Func.arg_empty() // returns i64 or i32 @@ -904,15 +898,6 @@ void LowerRaytracingPipelinePassImpl::handleGetFuncAddr(Function &Func) { }); } -void LowerRaytracingPipelinePassImpl::handleGetUninitialized(Function &Func) { - auto *ArgTy = Func.getReturnType(); - auto *Poison = PoisonValue::get(ArgTy); - llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { - CInst.replaceAllUsesWith(Poison); - CInst.eraseFromParent(); - }); -} - void llvm::copyBytes(IRBuilder<> &B, Value *Dst, Value *Src, uint64_t NumBytes) { assert(Dst->getType()->isPointerTy() && Src->getType()->isPointerTy() && @@ -1863,7 +1848,8 @@ void LowerRaytracingPipelinePassImpl::handleDriverFuncAssertions() { void LowerRaytracingPipelinePassImpl::handleAmdInternalFunc(Function &Func) { StringRef FuncName = Func.getName(); - if (FuncName.starts_with("_AmdAwait")) { + if (FuncName.starts_with("_AmdAwait") || + FuncName.starts_with("_AmdWaitAwait")) { Awaits.push_back(&Func); assert(!Func.arg_empty() // Function address @@ -1878,17 +1864,9 @@ void LowerRaytracingPipelinePassImpl::handleAmdInternalFunc(Function &Func) { && Func.getFunctionType()->getParamType(0)->isPointerTy()); } - if (FuncName == "_AmdContinuationStackIsGlobal") { - handleContinuationStackIsGlobal(Func); - } - if (FuncName.starts_with("_AmdGetFuncAddr")) { handleGetFuncAddr(Func); } - - if (FuncName.starts_with("_AmdGetUninitialized")) { - handleGetUninitialized(Func); - } } // Search for known intrinsics that cannot be rematerialized @@ -1897,14 +1875,11 @@ void LowerRaytracingPipelinePassImpl::handleUnrematerializableCandidates() { if (!DialectUtils::isLgcRtOp(&Func)) continue; - if (DialectUtils::isNoneOfDialectOpDeclaration( - Func)) { + static const OpSet NonRematerializableDialectOps = + OpSet::get(); + if (!NonRematerializableDialectOps.contains(Func)) { llvm::forEachCall(Func, [&](llvm::CallInst &CInst) { - // shader.index is handled separately - if (isa(CInst)) - return; - auto Data = ToProcess.find(CInst.getFunction()); if (Data != ToProcess.end()) { if (!isRematerializableLgcRtOp(CInst, Data->second.Kind)) @@ -1951,64 +1926,52 @@ bool LowerRaytracingPipelinePassImpl::run() { static const auto Visitor = llvm_dialects::VisitorBuilder() .setStrategy(llvm_dialects::VisitorStrategy::ByInstruction) - .add([](VisitorState &State, TraceRayOp &TraceRay) { - auto Data = State.Processables.find(TraceRay.getFunction()); + .addSet([](VisitorState &State, Instruction &Op) { + auto *CInst = cast(&Op); + auto Data = State.Processables.find(CInst->getFunction()); if (Data == State.Processables.end()) return; + if (isa(Op)) { + Data->second.ShaderIndexCalls.push_back(CInst); + return; + } + Type *PayloadTy = - DXILContHelper::getPayloadTypeFromMetadata(TraceRay); - PAQPayloadConfig PAQPayload = { - PayloadTy, Data->second.FuncConfig.MaxHitAttributeBytes}; - uint32_t PayloadStorageI32s = - State.PAQManager.getMaxPayloadStorageI32sForTraceRayFunc( - PAQPayload); - Data->second.MaxOutgoingPayloadI32s = std::max( - Data->second.MaxOutgoingPayloadI32s, PayloadStorageI32s); - - Data->second.TraceRayCalls.push_back(&TraceRay); - }) - .add( - [](VisitorState &State, - CallCallableShaderOp &CallCallableShader) { - auto Data = - State.Processables.find(CallCallableShader.getFunction()); - if (Data == State.Processables.end()) - return; - - Type *PayloadTy = DXILContHelper::getPayloadTypeFromMetadata( - CallCallableShader); - PAQPayloadConfig PAQPayload = { - PayloadTy, Data->second.FuncConfig.MaxHitAttributeBytes}; - uint32_t PayloadStorageI32s = - State.PAQManager.getMaxPayloadStorageI32sForCallShaderFunc( + DXILContHelper::getPayloadTypeFromMetadata(*CInst); + + if (!isa(Op)) { + PAQPayloadConfig PAQPayload = { + PayloadTy, Data->second.FuncConfig.MaxHitAttributeBytes}; + + uint32_t PayloadStorageI32s = 0; + if (isa(Op)) { + PayloadStorageI32s = + State.PAQManager.getMaxPayloadStorageI32sForTraceRayFunc( PAQPayload); - Data->second.MaxOutgoingPayloadI32s = std::max( - Data->second.MaxOutgoingPayloadI32s, PayloadStorageI32s); - - Data->second.CallShaderCalls.push_back(&CallCallableShader); - }) - .add([](VisitorState &State, auto &ReportHitOp) { - // The converter uses payload type metadata also to indicate hit - // attribute types - auto HitAttributesTy = - DXILContHelper::getPayloadTypeFromMetadata(ReportHitOp); - auto Data = State.Processables.find(ReportHitOp.getFunction()); - if (Data == State.Processables.end()) - return; - assert((!Data->second.HitAttributes || - Data->second.HitAttributes == HitAttributesTy) && - "Multiple reportHit calls with different hit attributes"); - Data->second.HitAttributes = HitAttributesTy; - Data->second.ReportHitCalls.push_back(&ReportHitOp); - }) - .add([](VisitorState &State, auto &ShaderIndexOp) { - auto Data = State.Processables.find(ShaderIndexOp.getFunction()); - if (Data == State.Processables.end()) - return; + Data->second.TraceRayCalls.push_back(CInst); + } else if (isa(Op)) { + PayloadStorageI32s = + State.PAQManager.getMaxPayloadStorageI32sForCallShaderFunc( + PAQPayload); - Data->second.ShaderIndexCalls.push_back(&ShaderIndexOp); + Data->second.CallShaderCalls.push_back(CInst); + } + + Data->second.MaxOutgoingPayloadI32s = std::max( + Data->second.MaxOutgoingPayloadI32s, PayloadStorageI32s); + } else { + // The converter uses payload type metadata also to indicate hit + // attribute types + assert((!Data->second.HitAttributes || + Data->second.HitAttributes == PayloadTy) && + "Multiple reportHit calls with different hit attributes"); + Data->second.HitAttributes = PayloadTy; + + Data->second.ReportHitCalls.push_back(CInst); + } }) .build(); @@ -2043,18 +2006,22 @@ bool LowerRaytracingPipelinePassImpl::run() { // Handle places after Awaits where system data is restored IRBuilder<> B(*Context); - llvm::forEachCall(RestoreSystemDatas, [&](llvm::CallInst &CInst) { - B.SetInsertPoint(&CInst); - handleRestoreSystemData(B, &CInst); - }); + for (llvm::Function *Func : RestoreSystemDatas) { + llvm::forEachCall(*Func, [&](llvm::CallInst &CInst) { + B.SetInsertPoint(&CInst); + handleRestoreSystemData(B, &CInst); + }); + } // Change specialized functions afterwards, so the payload or hit attributes // exist as the last argument - llvm::forEachCall(Awaits, [&](llvm::CallInst &CInst) { - auto Data = AwaitsToProcess.find(CInst.getFunction()); - if (Data != AwaitsToProcess.end()) - Data->second.AwaitCalls.push_back(&CInst); - }); + for (llvm::Function *Func : Awaits) { + llvm::forEachCall(*Func, [&](llvm::CallInst &CInst) { + auto Data = AwaitsToProcess.find(CInst.getFunction()); + if (Data != AwaitsToProcess.end()) + Data->second.AwaitCalls.push_back(&CInst); + }); + } for (auto &FuncData : AwaitsToProcess) { for (auto *Call : FuncData.second.AwaitCalls) { diff --git a/shared/continuations/test/dx/dxil-cont-prepare-traversal.ll b/shared/continuations/test/dx/dxil-cont-prepare-traversal.ll index 3ad887f99b..384e985aed 100644 --- a/shared/continuations/test/dx/dxil-cont-prepare-traversal.ll +++ b/shared/continuations/test/dx/dxil-cont-prepare-traversal.ll @@ -80,29 +80,28 @@ attributes #2 = { nounwind } ; PREPARE-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4 ; PREPARE-NEXT: [[TMP4:%.*]] = icmp eq i32 [[TMP3]], 0 ; PREPARE-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT_TRAVERSALDATA]], ptr [[TMP1]], i32 0, i32 0 -; PREPARE-NEXT: br i1 [[TMP4]], label [[TMP13:%.*]], label [[TMP6:%.*]] +; PREPARE-NEXT: br i1 [[TMP4]], label [[TMP12:%.*]], label [[TMP6:%.*]] ; PREPARE: 6: -; PREPARE-NEXT: [[TMP7:%.*]] = call i1 @_AmdContinuationStackIsGlobal() -; PREPARE-NEXT: [[TMP8:%.*]] = call i32 @_AmdContPayloadRegistersI32Count() -; PREPARE-NEXT: [[TMP9:%.*]] = call i32 @_AmdContPayloadRegistersGetI32(i32 0) +; PREPARE-NEXT: [[TMP7:%.*]] = call i32 @_AmdContPayloadRegistersI32Count() +; PREPARE-NEXT: [[TMP8:%.*]] = call i32 @_AmdContPayloadRegistersGetI32(i32 0) ; PREPARE-NEXT: call void @_AmdContPayloadRegistersSetI32(i32 0, i32 1) -; PREPARE-NEXT: [[TMP10:%.*]] = call i32 @_AmdValueI32CountSomething(ptr [[TMP1]]) -; PREPARE-NEXT: [[TMP11:%.*]] = call i32 @_AmdValueGetI32Something(ptr [[TMP1]], i32 0) +; PREPARE-NEXT: [[TMP9:%.*]] = call i32 @_AmdValueI32CountSomething(ptr [[TMP1]]) +; PREPARE-NEXT: [[TMP10:%.*]] = call i32 @_AmdValueGetI32Something(ptr [[TMP1]], i32 0) ; PREPARE-NEXT: call void @_AmdValueSetI32Something(ptr [[TMP1]], i32 0, i32 1) -; PREPARE-NEXT: [[A0:%.*]] = zext i1 [[TMP7]] to i32 -; PREPARE-NEXT: [[A1:%.*]] = add i32 [[A0]], [[TMP8]] -; PREPARE-NEXT: [[A2:%.*]] = add i32 [[A1]], [[TMP9]] -; PREPARE-NEXT: [[A3:%.*]] = add i32 [[A2]], [[TMP10]] -; PREPARE-NEXT: [[A4:%.*]] = add i32 [[A3]], [[TMP11]] +; PREPARE-NEXT: [[A0:%.*]] = zext i1 false to i32 +; PREPARE-NEXT: [[A1:%.*]] = add i32 [[A0]], [[TMP7]] +; PREPARE-NEXT: [[A2:%.*]] = add i32 [[A1]], [[TMP8]] +; PREPARE-NEXT: [[A3:%.*]] = add i32 [[A2]], [[TMP9]] +; PREPARE-NEXT: [[A4:%.*]] = add i32 [[A3]], [[TMP10]] ; PREPARE-NEXT: [[ADDR:%.*]] = zext i32 [[A4]] to i64 -; PREPARE-NEXT: [[TMP12:%.*]] = load [[STRUCT_SYSTEMDATA:%.*]], ptr [[TMP5]], align 4 -; PREPARE-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 [[ADDR]], i64 -1, i32 [[STACKPTR]], i64 ptrtoint (ptr @_AmdTraversal to i64), [[STRUCT_SYSTEMDATA]] [[TMP12]]) -; PREPARE-NEXT: br label [[TMP15:%.*]] -; PREPARE: 13: -; PREPARE-NEXT: [[TMP14:%.*]] = load [[STRUCT_SYSTEMDATA]], ptr [[TMP5]], align 4 -; PREPARE-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 0, i64 -1, i32 [[STACKPTR]], [[STRUCT_SYSTEMDATA]] [[TMP14]]) -; PREPARE-NEXT: br label [[TMP15]] -; PREPARE: 15: +; PREPARE-NEXT: [[TMP11:%.*]] = load [[STRUCT_SYSTEMDATA:%.*]], ptr [[TMP5]], align 4 +; PREPARE-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 [[ADDR]], i64 -1, i32 [[STACKPTR]], i64 ptrtoint (ptr @_AmdTraversal to i64), [[STRUCT_SYSTEMDATA]] [[TMP11]]) +; PREPARE-NEXT: br label [[TMP14:%.*]] +; PREPARE: 12: +; PREPARE-NEXT: [[TMP13:%.*]] = load [[STRUCT_SYSTEMDATA]], ptr [[TMP5]], align 4 +; PREPARE-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 0, i64 -1, i32 [[STACKPTR]], [[STRUCT_SYSTEMDATA]] [[TMP13]]) +; PREPARE-NEXT: br label [[TMP14]] +; PREPARE: 14: ; PREPARE-NEXT: ret void ; ; diff --git a/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-false.ll b/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-false.ll index c92501260b..fbcfd0e4b6 100644 --- a/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-false.ll +++ b/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-false.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 -; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,lower-raytracing-pipeline,lint' -S %s 2>%t.stderr | FileCheck %s +; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,dxil-cont-intrinsic-prepare,lint' -S %s 2>%t.stderr | FileCheck %s ; RUN: count 0 < %t.stderr %struct.DispatchSystemData = type { i32 } @@ -13,15 +13,10 @@ declare %struct.DispatchSystemData @_cont_SetupRayGen() declare !types !8 i32 @_cont_GetLocalRootIndex(%struct.DispatchSystemData*) define void @main() { -; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !6 !continuation.entry !12 !continuation.registercount !6 !continuation !13 { +; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !5 { ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SYSTEM_DATA_ALLOCA:%.*]] = alloca [[STRUCT_DISPATCHSYSTEMDATA:%.*]], align 8 -; CHECK-NEXT: [[TMP0:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @continuations.getSystemData.s_struct.DispatchSystemDatas() -; CHECK-NEXT: store [[STRUCT_DISPATCHSYSTEMDATA]] [[TMP0]], ptr [[SYSTEM_DATA_ALLOCA]], align 4 -; CHECK-NEXT: [[LOCAL_ROOT_INDEX:%.*]] = call i32 @_cont_GetLocalRootIndex(ptr [[SYSTEM_DATA_ALLOCA]]) -; CHECK-NEXT: call void @amd.dx.setLocalRootIndex(i32 [[LOCAL_ROOT_INDEX]]) ; CHECK-NEXT: store i1 false, ptr @debug_global, align 1 -; CHECK-NEXT: ret void, !continuation.registercount !9 +; CHECK-NEXT: ret void ; entry: %val = call i1 @_AmdContinuationStackIsGlobal() diff --git a/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-true.ll b/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-true.ll index 9626baf7d9..84e4a8e49b 100644 --- a/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-true.ll +++ b/shared/continuations/test/dx/intrinsics/continuation-stack-is-global-true.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 -; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,lower-raytracing-pipeline,lint' -S %s 2>%t.stderr | FileCheck %s +; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,dxil-cont-intrinsic-prepare,lint' -S %s 2>%t.stderr | FileCheck %s ; RUN: count 0 < %t.stderr %struct.DispatchSystemData = type { i32 } @@ -13,15 +13,10 @@ declare %struct.DispatchSystemData @_cont_SetupRayGen() declare !types !8 i32 @_cont_GetLocalRootIndex(%struct.DispatchSystemData*) define void @main() { -; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !6 !continuation.entry !12 !continuation.registercount !6 !continuation !13 { +; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !5 { ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SYSTEM_DATA_ALLOCA:%.*]] = alloca [[STRUCT_DISPATCHSYSTEMDATA:%.*]], align 8 -; CHECK-NEXT: [[TMP0:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @continuations.getSystemData.s_struct.DispatchSystemDatas() -; CHECK-NEXT: store [[STRUCT_DISPATCHSYSTEMDATA]] [[TMP0]], ptr [[SYSTEM_DATA_ALLOCA]], align 4 -; CHECK-NEXT: [[LOCAL_ROOT_INDEX:%.*]] = call i32 @_cont_GetLocalRootIndex(ptr [[SYSTEM_DATA_ALLOCA]]) -; CHECK-NEXT: call void @amd.dx.setLocalRootIndex(i32 [[LOCAL_ROOT_INDEX]]) ; CHECK-NEXT: store i1 true, ptr @debug_global, align 1 -; CHECK-NEXT: ret void, !continuation.registercount !9 +; CHECK-NEXT: ret void ; entry: %val = call i1 @_AmdContinuationStackIsGlobal() diff --git a/shared/continuations/test/dx/intrinsics/get-flags.ll b/shared/continuations/test/dx/intrinsics/get-flags.ll new file mode 100644 index 0000000000..5c2aafb0ff --- /dev/null +++ b/shared/continuations/test/dx/intrinsics/get-flags.ll @@ -0,0 +1,24 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3 +; RUN: opt --verify-each -passes='dxil-cont-intrinsic-prepare,lint' -S %s 2>%t.stderr | FileCheck %s +; RUN: count 0 < %t.stderr + +declare i32 @_AmdContinuationsGetFlags() + +@debug_global = external global i32 + +define void @main() !lgc.rt.shaderstage !1 { +; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !1 { +; CHECK-NEXT: entry: +; CHECK-NEXT: store i32 3, ptr @debug_global, align 4 +; CHECK-NEXT: ret void +; +entry: + %val = call i32 @_AmdContinuationsGetFlags() + store i32 %val, ptr @debug_global + ret void +} + +!continuation.flags = !{!0} + +!0 = !{i32 3} +!1 = !{i32 0} diff --git a/shared/continuations/test/dx/intrinsics/get-rtip.ll b/shared/continuations/test/dx/intrinsics/get-rtip.ll new file mode 100644 index 0000000000..51506d8251 --- /dev/null +++ b/shared/continuations/test/dx/intrinsics/get-rtip.ll @@ -0,0 +1,30 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3 +; RUN: opt --verify-each -passes='dxil-cont-intrinsic-prepare,lint' -S %s 2>%t.stderr | FileCheck %s +; RUN: count 0 < %t.stderr + +declare i32 @_AmdGetRtip() + +%struct.DispatchSystemData = type { i32 } +declare %struct.DispatchSystemData @_cont_SetupRayGen() +declare !types !8 i32 @_cont_GetLocalRootIndex(%struct.DispatchSystemData*) + +@debug_global = external global i32 + +define void @main() !lgc.rt.shaderstage !1 { +; CHECK-LABEL: define void @main() !lgc.rt.shaderstage !3 { +; CHECK-NEXT: entry: +; CHECK-NEXT: store i32 2, ptr @debug_global, align 4 +; CHECK-NEXT: ret void +; +entry: + %val = call i32 @_AmdGetRtip() + store i32 %val, ptr @debug_global + ret void +} + +!continuation.rtip = !{!0} + +!0 = !{i32 2} +!1 = !{i32 0} +!8 = !{!"function", i32 poison, !9} +!9 = !{i32 0, %struct.DispatchSystemData poison} diff --git a/shared/continuations/test/dx/lower-await.ll b/shared/continuations/test/dx/lower-await.ll index 8da928e9dc..cbec15e004 100644 --- a/shared/continuations/test/dx/lower-await.ll +++ b/shared/continuations/test/dx/lower-await.ll @@ -13,6 +13,7 @@ target datalayout = "e-m:e-p:64:32-p20:32:32-p21:32:32-i1:32-i8:8-i16:32-i32:32- declare void @await.void(%continuation.token*) declare i32 @await.i32(%continuation.token*) declare %continuation.token* @async_fun() +declare %continuation.token* @async_fun_with_waitmask(i64) declare %continuation.token* @async_fun_with_arg(i32) define void @simple_await() !continuation.registercount !1 { @@ -195,6 +196,55 @@ define i32 @await_with_ret_value() !continuation.registercount !1 { ret i32 %res, !continuation.registercount !1 } +define void @wait_await() !continuation.registercount !1 { +; AWAIT-LABEL: define { ptr, ptr } @wait_await( +; AWAIT-SAME: i32 [[CSPINIT:%.*]], i64 [[RETURNADDR:%.*]], ptr [[TMP0:%.*]]) !continuation.registercount !1 !continuation !7 { +; AWAIT-NEXT: [[TMP2:%.*]] = call token @llvm.coro.id.retcon(i32 8, i32 4, ptr [[TMP0]], ptr @continuation.prototype.wait_await, ptr @continuation.malloc, ptr @continuation.free) +; AWAIT-NEXT: [[TMP3:%.*]] = call ptr @llvm.coro.begin(token [[TMP2]], ptr null) +; AWAIT-NEXT: [[TOK:%.*]] = call ptr @async_fun_with_waitmask(i64 -1), !continuation.registercount !1, !continuation.returnedRegistercount !1, !continuation.wait.await !3 +; AWAIT-NEXT: [[TMP4:%.*]] = call i1 (...) @llvm.coro.suspend.retcon.i1(ptr [[TOK]]) +; AWAIT-NEXT: call void (...) @continuation.return(i64 [[RETURNADDR]]), !continuation.registercount !1 +; AWAIT-NEXT: unreachable +; +; CORO-LABEL: define { ptr, ptr } @wait_await( +; CORO-SAME: i32 [[CSPINIT:%.*]], i64 [[RETURNADDR:%.*]], ptr [[TMP0:%.*]]) !continuation.registercount !1 !continuation !7 { +; CORO-NEXT: AllocaSpillBB: +; CORO-NEXT: [[RETURNADDR_SPILL_ADDR:%.*]] = getelementptr inbounds [[WAIT_AWAIT_FRAME:%.*]], ptr [[TMP0]], i32 0, i32 0 +; CORO-NEXT: store i64 [[RETURNADDR]], ptr [[RETURNADDR_SPILL_ADDR]], align 4 +; CORO-NEXT: [[TOK:%.*]] = call ptr @async_fun_with_waitmask(i64 -1), !continuation.registercount !1, !continuation.returnedRegistercount !1, !continuation.wait.await !3 +; CORO-NEXT: [[TMP1:%.*]] = insertvalue { ptr, ptr } [[UNDEF_OR_POISON:undef|poison]], ptr @wait_await.resume.0, 0 +; CORO-NEXT: [[TMP2:%.*]] = insertvalue { ptr, ptr } [[TMP1]], ptr [[TOK]], 1 +; CORO-NEXT: ret { ptr, ptr } [[TMP2]] +; +; CLEANED-LABEL: define void @wait_await( +; CLEANED-SAME: i32 [[CSPINIT:%.*]], i64 [[RETURNADDR:%.*]]) !continuation.registercount !2 !continuation !9 !continuation.state !4 !continuation.stacksize !4 { +; CLEANED-NEXT: AllocaSpillBB: +; CLEANED-NEXT: [[CONT_STATE:%.*]] = alloca [2 x i32], align 4 +; CLEANED-NEXT: call void @continuation.save.continuation_state() +; CLEANED-NEXT: [[RETURNADDR_SPILL_ADDR:%.*]] = getelementptr inbounds [[WAIT_AWAIT_FRAME:%.*]], ptr [[CONT_STATE]], i32 0, i32 0 +; CLEANED-NEXT: store i64 [[RETURNADDR]], ptr [[RETURNADDR_SPILL_ADDR]], align 4 +; CLEANED-NEXT: [[TMP0:%.*]] = call ptr @continuation.getContinuationStackOffset() +; CLEANED-NEXT: [[TMP1:%.*]] = load i32, ptr [[TMP0]], align 4 +; CLEANED-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], 8 +; CLEANED-NEXT: store i32 [[TMP2]], ptr [[TMP0]], align 4 +; CLEANED-NEXT: [[TMP3:%.*]] = call ptr @continuation.getContinuationStackOffset() +; CLEANED-NEXT: call void (...) @registerbuffer.setpointerbarrier(ptr @CONTINUATION_STATE, ptr [[TMP3]]) +; CLEANED-NEXT: [[TMP4:%.*]] = getelementptr inbounds [2 x i32], ptr [[CONT_STATE]], i32 0, i32 0 +; CLEANED-NEXT: [[TMP5:%.*]] = load i32, ptr [[TMP4]], align 4 +; CLEANED-NEXT: store i32 [[TMP5]], ptr @CONTINUATION_STATE, align 4 +; CLEANED-NEXT: [[TMP6:%.*]] = getelementptr inbounds [2 x i32], ptr [[CONT_STATE]], i32 0, i32 1 +; CLEANED-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4 +; CLEANED-NEXT: store i32 [[TMP7]], ptr getelementptr inbounds ([2 x i32], ptr @CONTINUATION_STATE, i32 0, i32 1), align 4 +; CLEANED-NEXT: [[TMP8:%.*]] = call ptr @continuation.getContinuationStackOffset() +; CLEANED-NEXT: [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4 +; CLEANED-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 ptrtoint (ptr @async_fun_with_waitmask to i64), i64 -1, i32 [[TMP9]], i64 ptrtoint (ptr @wait_await.resume.0 to i64)), !continuation.registercount !2, !continuation.returnedRegistercount !2 +; CLEANED-NEXT: unreachable +; + %tok = call %continuation.token* @async_fun_with_waitmask(i64 -1), !continuation.wait.await !0, !continuation.registercount !1, !continuation.returnedRegistercount !1 + call void @await.void(%continuation.token* %tok) + ret void, !continuation.registercount !1 +} + !continuation.stackAddrspace = !{!2} !0 = !{} diff --git a/shared/continuations/test/dx/lower-rt-pipeline-simple-call-shader.ll b/shared/continuations/test/dx/lower-rt-pipeline-simple-call-shader.ll index 9988a633dc..f69cb7323b 100644 --- a/shared/continuations/test/dx/lower-rt-pipeline-simple-call-shader.ll +++ b/shared/continuations/test/dx/lower-rt-pipeline-simple-call-shader.ll @@ -9,6 +9,9 @@ ; RUN: count 0 < %t2.stderr ; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,lower-raytracing-pipeline,lint,remove-types-metadata' -S %s 2>%t3.stderr | FileCheck -check-prefix=LOWERRAYTRACINGPIPELINE-CPS %s ; RUN: count 0 < %t3.stderr +; RUN: opt --verify-each -passes='dxil-cont-lgc-rt-op-converter,lint,lower-raytracing-pipeline,lint,inline,lint,pre-coroutine-lowering,lint,sroa,lint,lower-await,lint,coro-early,dxil-coro-split,coro-cleanup,lint,cleanup-continuations,lint,remove-types-metadata' \ +; RUN: -S %s 2>%t4.stderr | FileCheck -check-prefix=CLEANUP-CPS %s +; RUN: count 0 < %t4.stderr target datalayout = "e-m:e-p:64:32-p20:32:32-p21:32:32-i1:32-i8:8-i16:32-i32:32-i64:32-f16:32-f32:32-f64:32-v16:32-v32:32-v48:32-v64:32-v80:32-v96:32-v112:32-v128:32-v144:32-v160:32-v176:32-v192:32-v208:32-v224:32-v240:32-v256:32-n8:16:32" @@ -38,7 +41,7 @@ define i32 @_cont_GetLocalRootIndex(%struct.DispatchSystemData* %data) !types !1 ret i32 5 } -define void @_cont_CallShader(%struct.DispatchSystemData* %data, i32 %0) !types !18 { +define void @_cont_CallShader(%struct.DispatchSystemData* %data, i32 %0) #1 !types !18 { %dis_data = load %struct.DispatchSystemData, %struct.DispatchSystemData* %data, align 4 %newdata = call %struct.DispatchSystemData @_AmdAwaitShader(i64 2, %struct.DispatchSystemData %dis_data) store %struct.DispatchSystemData %newdata, %struct.DispatchSystemData* %data, align 4 @@ -55,6 +58,7 @@ define void @called(%struct.MyParams* %params) !types !19 { declare !types !21 void @dx.op.callShader.struct.MyParams(i32, i32, %struct.MyParams*) #0 attributes #0 = { nounwind } +attributes #1 = { alwaysinline } !llvm.ident = !{!0} !dx.version = !{!1} @@ -91,7 +95,7 @@ attributes #0 = { nounwind } ; ; ; LOWERRAYTRACINGPIPELINE-LABEL: define void @_cont_CallShader.struct.MyParams( -; LOWERRAYTRACINGPIPELINE-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]], ptr [[TMP1:%.*]]) { +; LOWERRAYTRACINGPIPELINE-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]], ptr [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] { ; LOWERRAYTRACINGPIPELINE-NEXT: [[DIS_DATA:%.*]] = load [[STRUCT_DISPATCHSYSTEMDATA:%.*]], ptr [[DATA]], align 4 ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[STRUCT_MYPARAMS:%.*]], ptr [[TMP1]], i32 0, i32 0 ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP4:%.*]] = load i32, ptr [[TMP3]], align 4 @@ -110,7 +114,7 @@ attributes #0 = { nounwind } ; ; ; LOWERRAYTRACINGPIPELINE-LABEL: define void @_cont_CallShader( -; LOWERRAYTRACINGPIPELINE-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]]) { +; LOWERRAYTRACINGPIPELINE-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]]) #[[ATTR0]] { ; LOWERRAYTRACINGPIPELINE-NEXT: [[DIS_DATA:%.*]] = load [[STRUCT_DISPATCHSYSTEMDATA:%.*]], ptr [[DATA]], align 4 ; LOWERRAYTRACINGPIPELINE-NEXT: [[NEWDATA:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @_AmdAwaitShader(i64 2, [[STRUCT_DISPATCHSYSTEMDATA]] [[DIS_DATA]]) ; LOWERRAYTRACINGPIPELINE-NEXT: store [[STRUCT_DISPATCHSYSTEMDATA]] [[NEWDATA]], ptr [[DATA]], align 4 @@ -303,7 +307,7 @@ attributes #0 = { nounwind } ; ; ; LOWERRAYTRACINGPIPELINE-CPS-LABEL: define void @_cont_CallShader.struct.MyParams( -; LOWERRAYTRACINGPIPELINE-CPS-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]], ptr [[TMP1:%.*]]) { +; LOWERRAYTRACINGPIPELINE-CPS-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]], ptr [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] { ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: [[DIS_DATA:%.*]] = load [[STRUCT_DISPATCHSYSTEMDATA:%.*]], ptr [[DATA]], align 4 ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[STRUCT_MYPARAMS:%.*]], ptr [[TMP1]], i32 0, i32 0 ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: [[TMP4:%.*]] = load i32, ptr [[TMP3]], align 4 @@ -321,7 +325,7 @@ attributes #0 = { nounwind } ; ; ; LOWERRAYTRACINGPIPELINE-CPS-LABEL: define void @_cont_CallShader( -; LOWERRAYTRACINGPIPELINE-CPS-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]]) { +; LOWERRAYTRACINGPIPELINE-CPS-SAME: ptr [[DATA:%.*]], i32 [[TMP0:%.*]]) #[[ATTR0]] { ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: [[DIS_DATA:%.*]] = load [[STRUCT_DISPATCHSYSTEMDATA:%.*]], ptr [[DATA]], align 4 ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: [[NEWDATA:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @_AmdAwaitShader(i64 2, [[STRUCT_DISPATCHSYSTEMDATA]] [[DIS_DATA]]) ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: store [[STRUCT_DISPATCHSYSTEMDATA]] [[NEWDATA]], ptr [[DATA]], align 4 @@ -348,3 +352,41 @@ attributes #0 = { nounwind } ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: call void (...) @lgc.cps.jump(i32 [[RETURN_ADDR]], i32 2, {} poison, [[STRUCT_DISPATCHSYSTEMDATA]] [[TMP6]]) ; LOWERRAYTRACINGPIPELINE-CPS-NEXT: unreachable ; +; +; CLEANUP-CPS-LABEL: define i32 @_cont_GetLocalRootIndex( +; CLEANUP-CPS-SAME: ptr [[DATA:%.*]]) { +; CLEANUP-CPS-NEXT: ret i32 5 +; +; +; CLEANUP-CPS-LABEL: define void @called( +; CLEANUP-CPS-SAME: {} [[CONT_STATE:%.*]], i32 [[RETURN_ADDR:%.*]], i32 [[SHADER_INDEX:%.*]], [[STRUCT_DISPATCHSYSTEMDATA:%.*]] [[TMP0:%.*]]) !lgc.rt.shaderstage !16 !lgc.cps !17 !continuation !18 { +; CLEANUP-CPS-NEXT: AllocaSpillBB: +; CLEANUP-CPS-NEXT: [[TMP1:%.*]] = call ptr addrspace(32) @lgc.cps.alloc(i32 8) +; CLEANUP-CPS-NEXT: [[RETURN_ADDR_SPILL_ADDR:%.*]] = getelementptr inbounds [[CALLED_FRAME:%.*]], ptr addrspace(32) [[TMP1]], i32 0, i32 0 +; CLEANUP-CPS-NEXT: store i32 [[RETURN_ADDR]], ptr addrspace(32) [[RETURN_ADDR_SPILL_ADDR]], align 4 +; CLEANUP-CPS-NEXT: [[TMP2:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @continuations.getSystemData.s_struct.DispatchSystemDatas() +; CLEANUP-CPS-NEXT: [[DOTFCA_0_EXTRACT:%.*]] = extractvalue [[STRUCT_DISPATCHSYSTEMDATA]] [[TMP2]], 0 +; CLEANUP-CPS-NEXT: call void @amd.dx.setLocalRootIndex(i32 5) +; CLEANUP-CPS-NEXT: [[DIS_DATA_I_FCA_0_INSERT:%.*]] = insertvalue [[STRUCT_DISPATCHSYSTEMDATA]] poison, i32 [[DOTFCA_0_EXTRACT]], 0 +; CLEANUP-CPS-NEXT: store i32 undef, ptr @PAYLOAD, align 4 +; CLEANUP-CPS-NEXT: [[TMP3:%.*]] = call i32 (...) @lgc.cps.as.continuation.reference(ptr @called.resume.0) +; CLEANUP-CPS-NEXT: call void (...) @lgc.cps.jump(i32 2, i32 2, {} poison, i32 [[TMP3]], [[STRUCT_DISPATCHSYSTEMDATA]] [[DIS_DATA_I_FCA_0_INSERT]]) +; CLEANUP-CPS-NEXT: unreachable +; +; +; CLEANUP-CPS-LABEL: define void @called.resume.0( +; CLEANUP-CPS-SAME: {} [[TMP0:%.*]], i32 [[TMP1:%.*]], [[STRUCT_DISPATCHSYSTEMDATA:%.*]] [[TMP2:%.*]]) !lgc.rt.shaderstage !16 !lgc.cps !17 !continuation !18 { +; CLEANUP-CPS-NEXT: entryresume.0: +; CLEANUP-CPS-NEXT: [[TMP3:%.*]] = call ptr addrspace(32) @lgc.cps.peek(i32 8) +; CLEANUP-CPS-NEXT: [[TMP4:%.*]] = load i32, ptr @PAYLOAD, align 4 +; CLEANUP-CPS-NEXT: [[DOTFCA_0_EXTRACT3:%.*]] = extractvalue [[STRUCT_DISPATCHSYSTEMDATA]] [[TMP2]], 0 +; CLEANUP-CPS-NEXT: call void @amd.dx.setLocalRootIndex(i32 5) +; CLEANUP-CPS-NEXT: [[RETURN_ADDR_RELOAD_ADDR:%.*]] = getelementptr inbounds [[CALLED_FRAME:%.*]], ptr addrspace(32) [[TMP3]], i32 0, i32 0 +; CLEANUP-CPS-NEXT: [[RETURN_ADDR_RELOAD:%.*]] = load i32, ptr addrspace(32) [[RETURN_ADDR_RELOAD_ADDR]], align 4 +; CLEANUP-CPS-NEXT: call void (...) @registerbuffer.setpointerbarrier(ptr @PAYLOAD) +; CLEANUP-CPS-NEXT: store i32 [[TMP4]], ptr @PAYLOAD, align 4 +; CLEANUP-CPS-NEXT: [[DOTFCA_0_INSERT:%.*]] = insertvalue [[STRUCT_DISPATCHSYSTEMDATA]] poison, i32 [[DOTFCA_0_EXTRACT3]], 0 +; CLEANUP-CPS-NEXT: call void @lgc.cps.free(i32 8) +; CLEANUP-CPS-NEXT: call void (...) @lgc.cps.jump(i32 [[RETURN_ADDR_RELOAD]], i32 2, {} poison, [[STRUCT_DISPATCHSYSTEMDATA]] [[DOTFCA_0_INSERT]]) +; CLEANUP-CPS-NEXT: unreachable +; diff --git a/shared/continuations/test/dx/traceray.ll b/shared/continuations/test/dx/traceray.ll index c71bc1c233..98f262bce3 100644 --- a/shared/continuations/test/dx/traceray.ll +++ b/shared/continuations/test/dx/traceray.ll @@ -26,7 +26,9 @@ declare i32 @_cont_GetContinuationStackAddr() #0 declare %struct.DispatchSystemData @_cont_SetupRayGen() #0 -declare %struct.DispatchSystemData @_AmdAwaitTraversal(i64, %struct.TraversalData) #0 +; To exercise both waiting and non-waiting Await, we use WaitAwait for Traversal, +; and Await for Callshader. This does not necessarily reflect current choices in GPURT. +declare %struct.DispatchSystemData @_AmdWaitAwaitTraversal(i64, i64, %struct.TraversalData) #0 declare %struct.DispatchSystemData @_AmdAwaitShader(i64, %struct.DispatchSystemData) #0 @@ -80,7 +82,7 @@ define void @_cont_TraceRay(%struct.DispatchSystemData* %data, i64 %0, i32 %1, i %trav_data = insertvalue %struct.TraversalData undef, %struct.SystemData %sys_data, 0 %addr = call i64 @_AmdGetResumePointAddr() #3 %trav_data2 = insertvalue %struct.TraversalData %trav_data, i64 %addr, 5 - %newdata = call %struct.DispatchSystemData @_AmdAwaitTraversal(i64 4, %struct.TraversalData %trav_data2) + %newdata = call %struct.DispatchSystemData @_AmdWaitAwaitTraversal(i64 4, i64 -1, %struct.TraversalData %trav_data2) store %struct.DispatchSystemData %newdata, %struct.DispatchSystemData* %data, align 4 call void @_AmdRestoreSystemData(%struct.DispatchSystemData* %data) ret void @@ -413,7 +415,7 @@ attributes #6 = { nocallback nofree nosync nounwind willreturn memory(argmem: re ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP25:%.*]] = getelementptr i32, ptr [[TMP20]], i64 2 ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4 ; LOWERRAYTRACINGPIPELINE-NEXT: store i32 [[TMP26]], ptr getelementptr ([[STRUCT_RAYPAYLOAD_ATTR_MAX_8_I32S_LAYOUT_0_CALLER_OUT]], ptr @PAYLOAD, i32 0, i32 0, i64 9), align 4 -; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP27:%.*]] = call ptr inttoptr (i64 4 to ptr)([[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2]]), !continuation.registercount !35, !continuation.returnedRegistercount !35 +; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP27:%.*]] = call ptr inttoptr (i64 4 to ptr)(i64 -1, [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2]]), !continuation.registercount !35, !continuation.returnedRegistercount !35, !continuation.wait.await !14 ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP28:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @await.struct.DispatchSystemData(ptr [[TMP27]]) ; LOWERRAYTRACINGPIPELINE-NEXT: store [[STRUCT_RAYPAYLOAD]] poison, ptr [[TMP14]], align 4 ; LOWERRAYTRACINGPIPELINE-NEXT: [[TMP29:%.*]] = getelementptr inbounds [[STRUCT_RAYPAYLOAD]], ptr [[TMP14]], i32 0, i32 0 @@ -445,7 +447,7 @@ attributes #6 = { nocallback nofree nosync nounwind willreturn memory(argmem: re ; LOWERRAYTRACINGPIPELINE-NEXT: [[TRAV_DATA:%.*]] = insertvalue [[STRUCT_TRAVERSALDATA:%.*]] undef, [[STRUCT_SYSTEMDATA]] [[SYS_DATA]], 0 ; LOWERRAYTRACINGPIPELINE-NEXT: [[ADDR:%.*]] = call i64 @_AmdGetResumePointAddr() #[[ATTR3]] ; LOWERRAYTRACINGPIPELINE-NEXT: [[TRAV_DATA2:%.*]] = insertvalue [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA]], i64 [[ADDR]], 5 -; LOWERRAYTRACINGPIPELINE-NEXT: [[NEWDATA:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @_AmdAwaitTraversal(i64 4, [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2]]) +; LOWERRAYTRACINGPIPELINE-NEXT: [[NEWDATA:%.*]] = call [[STRUCT_DISPATCHSYSTEMDATA]] @_AmdWaitAwaitTraversal(i64 4, i64 -1, [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2]]) ; LOWERRAYTRACINGPIPELINE-NEXT: store [[STRUCT_DISPATCHSYSTEMDATA]] [[NEWDATA]], ptr [[DATA]], align 4 ; LOWERRAYTRACINGPIPELINE-NEXT: [[LOCAL_ROOT_INDEX:%.*]] = call i32 @_cont_GetLocalRootIndex(ptr [[DATA]]) ; LOWERRAYTRACINGPIPELINE-NEXT: call void @amd.dx.setLocalRootIndex(i32 [[LOCAL_ROOT_INDEX]]) @@ -1002,7 +1004,7 @@ attributes #6 = { nocallback nofree nosync nounwind willreturn memory(argmem: re ; DXILCONTPOSTPROCESS-NEXT: [[TMP12:%.*]] = bitcast float [[DOTSROA_0_12_VEC_EXTRACT]] to i32 ; DXILCONTPOSTPROCESS-NEXT: store i32 [[TMP12]], ptr addrspace(20) addrspacecast (ptr getelementptr ([[STRUCT_RAYPAYLOAD_ATTR_MAX_8_I32S_LAYOUT_0_CALLER_OUT]], ptr addrspacecast (ptr addrspace(20) @REGISTERS to ptr), i32 0, i32 0, i64 9) to ptr addrspace(20)), align 4 ; DXILCONTPOSTPROCESS-NEXT: [[TMP13:%.*]] = load i32, ptr [[CSP]], align 4 -; DXILCONTPOSTPROCESS-NEXT: call void (i64, ...) @continuation.continue(i64 4, i32 [[TMP13]], [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2_I]]), !continuation.registercount !35, !continuation.returnedRegistercount !35 +; DXILCONTPOSTPROCESS-NEXT: call void (i64, i64, ...) @continuation.waitContinue(i64 4, i64 -1, i32 [[TMP13]], [[STRUCT_TRAVERSALDATA]] [[TRAV_DATA2_I]]), !continuation.registercount !35, !continuation.returnedRegistercount !35 ; DXILCONTPOSTPROCESS-NEXT: unreachable ; ; diff --git a/shared/continuations/test/intrinsics/discard-values.ll b/shared/continuations/test/intrinsics/discard-values.ll index 6059c40f61..399fbeef80 100644 --- a/shared/continuations/test/intrinsics/discard-values.ll +++ b/shared/continuations/test/intrinsics/discard-values.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 -; RUN: opt --verify-each -passes='dxil-cont-intrinsic-prepare,sroa,lint,lower-raytracing-pipeline,lint' -S %s 2>%t.stderr | FileCheck %s +; RUN: opt --verify-each -passes='dxil-cont-intrinsic-prepare,lint' -S %s 2>%t.stderr | FileCheck %s ; RUN: count 0 < %t.stderr %struct.AnyHitData = type { float, i32 } diff --git a/shared/continuations/test/lgccps/alloca-select.ll b/shared/continuations/test/lgccps/alloca-select.ll index a2e12fa45a..32b00e3f1f 100644 --- a/shared/continuations/test/lgccps/alloca-select.ll +++ b/shared/continuations/test/lgccps/alloca-select.ll @@ -1,18 +1,18 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 2 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, float) define void @test({} %state, i32 %rcr, float %arg, i32 %arg1) !lgc.cps !0 { - %a1 = alloca i32, align 4, addrspace(5) - %a2 = alloca i32, align 4, addrspace(5) + %a1 = alloca i32 + %a2 = alloca i32 %cond = icmp ult i32 %arg1, 0 - %p = select i1 %cond, ptr addrspace(5) %a1, ptr addrspace(5) %a2 - store i32 111, ptr addrspace(5) %p, align 4 + %p = select i1 %cond, ptr %a1, ptr %a2 + store i32 111, ptr %p, align 4 %t0 = fadd float %arg, 1.0 %cr = call i32 @lgc.cps.as.continuation.reference(ptr @callee) %t1 = call float (...) @lgc.cps.await.f32(i32 %cr, i32 2, float %t0) %tmp = fmul float %t1, %arg - %v111 = load float, ptr addrspace(5) %p, align 4 + %v111 = load float, ptr %p, align 4 %returnvalue = fmul float %tmp, %v111 call void (...) @lgc.cps.jump(i32 %rcr, i32 2, {} poison, i32 poison, float %returnvalue) unreachable diff --git a/shared/continuations/test/lgccps/await-if-else.ll b/shared/continuations/test/lgccps/await-if-else.ll index a02dab64ad..e6380db3bd 100644 --- a/shared/continuations/test/lgccps/await-if-else.ll +++ b/shared/continuations/test/lgccps/await-if-else.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, float) declare !lgc.cps !0 void @callee2({}, i32, float) diff --git a/shared/continuations/test/lgccps/await-if.ll b/shared/continuations/test/lgccps/await-if.ll index 429d9020b8..6f5f78615c 100644 --- a/shared/continuations/test/lgccps/await-if.ll +++ b/shared/continuations/test/lgccps/await-if.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, float) diff --git a/shared/continuations/test/lgccps/await-in-loop.ll b/shared/continuations/test/lgccps/await-in-loop.ll index 2e46df4d2f..76e2f1d75e 100644 --- a/shared/continuations/test/lgccps/await-in-loop.ll +++ b/shared/continuations/test/lgccps/await-in-loop.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, i32) diff --git a/shared/continuations/test/lgccps/entry-point-with-cps.ll b/shared/continuations/test/lgccps/entry-point-with-cps.ll index 323bfec784..2d0ea41865 100644 --- a/shared/continuations/test/lgccps/entry-point-with-cps.ll +++ b/shared/continuations/test/lgccps/entry-point-with-cps.ll @@ -1,6 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s -; RUN: opt -S -o - -passes='lower-await' %s | FileCheck --check-prefixes=LOWER-AWAIT %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await' %s | FileCheck --check-prefixes=LOWER-AWAIT %s ; This is example output for running continufy on the -in file. ; Details of the output are likely to differ from the final production pass, diff --git a/shared/continuations/test/lgccps/multiple-await.ll b/shared/continuations/test/lgccps/multiple-await.ll index f1b38920e0..3500f49d40 100644 --- a/shared/continuations/test/lgccps/multiple-await.ll +++ b/shared/continuations/test/lgccps/multiple-await.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, float) declare !lgc.cps !0 void @callee2({}, i32, float) diff --git a/shared/continuations/test/lgccps/simple-await-more-state.ll b/shared/continuations/test/lgccps/simple-await-more-state.ll index 036ef96d9e..7ed851fdcb 100644 --- a/shared/continuations/test/lgccps/simple-await-more-state.ll +++ b/shared/continuations/test/lgccps/simple-await-more-state.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s declare !lgc.cps !0 void @callee({}, i32, float) diff --git a/shared/continuations/test/lgccps/simple-await.ll b/shared/continuations/test/lgccps/simple-await.ll index d8386e0584..ae9dae00fd 100644 --- a/shared/continuations/test/lgccps/simple-await.ll +++ b/shared/continuations/test/lgccps/simple-await.ll @@ -1,6 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 3 -; RUN: opt -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s -; RUN: opt -S -o - -passes='lower-await' %s | FileCheck --check-prefixes=LOWER-AWAIT %s +; RUN: opt --verify-each -S -o - -passes='lower-await,coro-early,lgc-coro-split,coro-cleanup,cleanup-continuations' %s | FileCheck --check-prefixes=CHECK %s +; RUN: opt --verify-each -S -o - -passes='lower-await' %s | FileCheck --check-prefixes=LOWER-AWAIT %s declare !lgc.cps !0 void @callee({}, i32, float) define void @test({} %state, i32 %rcr, float %arg) !lgc.cps !0 { diff --git a/tool/dumper/vkgcPipelineDumper.cpp b/tool/dumper/vkgcPipelineDumper.cpp index 3254e0e6ad..c2ee3a7074 100644 --- a/tool/dumper/vkgcPipelineDumper.cpp +++ b/tool/dumper/vkgcPipelineDumper.cpp @@ -862,6 +862,7 @@ void PipelineDumper::dumpPipelineOptions(const PipelineOptions *options, std::os dumpFile << "options.disableTruncCoordForGather = " << options->disableTruncCoordForGather << "\n"; dumpFile << "options.vertex64BitsAttribSingleLoc = " << options->vertex64BitsAttribSingleLoc << "\n"; dumpFile << "options.enablePrimGeneratedQuery = " << options->enablePrimGeneratedQuery << "\n"; + dumpFile << "options.enableFragColor = " << options->enableFragColor << "\n"; } // ===================================================================================================================== @@ -1017,7 +1018,7 @@ void PipelineDumper::dumpGraphicsStateInfo(const GraphicsPipelineBuildInfo *pipe dumpFile << "\n[ApiXfbOutInfo]\n"; dumpFile << "forceDisableStreamOut = " << pipelineInfo->apiXfbOutData.forceDisableStreamOut << "\n"; -#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 69 +#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 70 dumpFile << "forceEnablePrimStats = " << pipelineInfo->apiXfbOutData.forceEnablePrimStats << "\n"; #endif const auto pXfbOutInfos = pipelineInfo->apiXfbOutData.pXfbOutInfos; @@ -1569,7 +1570,7 @@ void PipelineDumper::updateHashForNonFragmentState(const GraphicsPipelineBuildIn } hasher->Update(pipeline->apiXfbOutData.forceDisableStreamOut); -#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 69 +#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 70 hasher->Update(pipeline->apiXfbOutData.forceEnablePrimStats); #endif } @@ -1659,6 +1660,7 @@ void PipelineDumper::updateHashForPipelineOptions(const PipelineOptions *options hasher->Update(options->replaceSetWithResourceType); hasher->Update(options->disableTruncCoordForGather); hasher->Update(options->enablePrimGeneratedQuery); + hasher->Update(options->enableFragColor); } // ===================================================================================================================== @@ -2574,6 +2576,8 @@ std::ostream &operator<<(std::ostream &out, VkFormat format) { CASE_ENUM_TO_STRING(VK_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG) CASE_ENUM_TO_STRING(VK_FORMAT_A4R4G4B4_UNORM_PACK16_EXT) CASE_ENUM_TO_STRING(VK_FORMAT_A4B4G4R4_UNORM_PACK16_EXT) + CASE_ENUM_TO_STRING(VK_FORMAT_A1B5G5R5_UNORM_PACK16) + CASE_ENUM_TO_STRING(VK_FORMAT_A8_UNORM_KHR) break; default: diff --git a/tool/vfx/vfxEnumsConverter.cpp b/tool/vfx/vfxEnumsConverter.cpp index 6fc6474507..314f0c6dec 100644 --- a/tool/vfx/vfxEnumsConverter.cpp +++ b/tool/vfx/vfxEnumsConverter.cpp @@ -376,6 +376,8 @@ void initEnumMap() { ADD_ENUM_MAP(VkFormat, VK_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG); ADD_ENUM_MAP(VkFormat, VK_FORMAT_A4R4G4B4_UNORM_PACK16_EXT); ADD_ENUM_MAP(VkFormat, VK_FORMAT_A4B4G4R4_UNORM_PACK16_EXT); + ADD_ENUM_MAP(VkFormat, VK_FORMAT_A1B5G5R5_UNORM_PACK16); + ADD_ENUM_MAP(VkFormat, VK_FORMAT_A8_UNORM_KHR); ADD_ENUM_MAP(VkFormat, VK_FORMAT_MAX_ENUM); ADD_ENUM_MAP(VkImageType, VK_IMAGE_TYPE_1D); ADD_ENUM_MAP(VkImageType, VK_IMAGE_TYPE_2D); diff --git a/tool/vfx/vfxVkSection.h b/tool/vfx/vfxVkSection.h index d42fcd4b34..d8ba2e5163 100644 --- a/tool/vfx/vfxVkSection.h +++ b/tool/vfx/vfxVkSection.h @@ -499,6 +499,7 @@ class SectionPipelineOption : public Section { INIT_STATE_MEMBER_NAME_TO_ADDR(SectionPipelineOption, buildResourcesDataForShaderModule, MemberTypeBool, false); INIT_STATE_MEMBER_NAME_TO_ADDR(SectionPipelineOption, disableTruncCoordForGather, MemberTypeBool, false); INIT_STATE_MEMBER_NAME_TO_ADDR(SectionPipelineOption, vertex64BitsAttribSingleLoc, MemberTypeBool, false); + INIT_STATE_MEMBER_NAME_TO_ADDR(SectionPipelineOption, enableFragColor, MemberTypeBool, false); INIT_STATE_MEMBER_NAME_TO_ADDR(SectionPipelineOption, enablePrimGeneratedQuery, MemberTypeBool, false); return addrTableInitializer; }(); @@ -1032,7 +1033,7 @@ class SectionApiXfbOutput : public Section { static std::vector addrTable = []() { std::vector addrTableInitializer; INIT_STATE_MEMBER_NAME_TO_ADDR(SectionApiXfbOutput, forceDisableStreamOut, MemberTypeBool, false); -#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 69 +#if LLPC_CLIENT_INTERFACE_MAJOR_VERSION < 70 INIT_STATE_MEMBER_NAME_TO_ADDR(SectionApiXfbOutput, forceEnablePrimStats, MemberTypeBool, false); #endif INIT_MEMBER_DYNARRAY_NAME_TO_ADDR(SectionApiXfbOutput, m_xfbOutInfo, MemberTypeXfbOutInfo, true); diff --git a/util/extensions.txt b/util/extensions.txt index 95c58023fd..d3cf249784 100644 --- a/util/extensions.txt +++ b/util/extensions.txt @@ -41,4 +41,8 @@ SPV_KHR_ray_tracing SPV_KHR_ray_query SPV_KHR_fragment_shader_barycentric SPV_KHR_workgroup_memory_explicit_layout +#if VKI_COOPERATIVE_MATRIX +SPV_NV_cooperative_matrix +SPV_KHR_cooperative_matrix +#endif SPV_NV_shader_atomic_float diff --git a/util/vkgcCapability.h b/util/vkgcCapability.h index 96cbb1af89..9d2445353f 100644 --- a/util/vkgcCapability.h +++ b/util/vkgcCapability.h @@ -159,6 +159,7 @@ static const char *const VkgcSupportedCapabilities[] = { "CapabilityRayTraversalPrimitiveCullingProvisionalKHR", "CapabilityRayTracingPositionFetchKHR", "CapabilityRayQueryPositionFetchKHR", + "CapabilityCooperativeMatrixKHR", }; }; // namespace Vkgc