From e5ba891624d85ed3ccbb33ae470b13f9239c6c9e Mon Sep 17 00:00:00 2001 From: xazhang Date: Thu, 9 Nov 2023 11:51:11 -0500 Subject: [PATCH] lgc: Move handling of GroupMemcpy for mesh/task shaders to MeshTaskShader --- lgc/include/lgc/patch/Patch.h | 5 + lgc/include/lgc/patch/PatchEntryPointMutate.h | 4 +- lgc/patch/MeshTaskShader.cpp | 41 +++- lgc/patch/MeshTaskShader.h | 1 + lgc/patch/Patch.cpp | 65 ++++++ lgc/patch/PatchEntryPointMutate.cpp | 201 ++++-------------- 6 files changed, 158 insertions(+), 159 deletions(-) diff --git a/lgc/include/lgc/patch/Patch.h b/lgc/include/lgc/patch/Patch.h index 7c52020eae..aeeecc3faa 100644 --- a/lgc/include/lgc/patch/Patch.h +++ b/lgc/include/lgc/patch/Patch.h @@ -42,6 +42,8 @@ class PassBuilder; namespace lgc { +class BuilderBase; +class GroupMemcpyOp; class PipelineState; class PassManager; @@ -63,6 +65,9 @@ class Patch { static llvm::GlobalVariable *getLdsVariable(PipelineState *pipelineState, llvm::Module *module); + static void commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder, + llvm::Value *threadIndex, unsigned scopeSize); + protected: static void addOptimizationPasses(lgc::PassManager &passMgr, uint32_t optLevel); diff --git a/lgc/include/lgc/patch/PatchEntryPointMutate.h b/lgc/include/lgc/patch/PatchEntryPointMutate.h index 40fd5ae48b..cce023bdea 100644 --- a/lgc/include/lgc/patch/PatchEntryPointMutate.h +++ b/lgc/include/lgc/patch/PatchEntryPointMutate.h @@ -162,8 +162,8 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin() .setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration) + .add(&MeshTaskShader::lowerGroupMemcpy) .add(&MeshTaskShader::lowerTaskPayloadPtr) .add(&MeshTaskShader::lowerEmitMeshTasks) .build(); @@ -919,6 +920,44 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) { } } +// ===================================================================================================================== +// Lower GroupMemcpyOp - copy memory using all threads in a workgroup. +void MeshTaskShader::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) { + Function *entryPoint = groupMemcpyOp.getFunction(); + ShaderStage stage = getShaderStage(entryPoint); + m_builder.SetInsertPoint(&groupMemcpyOp); + + unsigned scopeSize = 0; + Value *threadIndex = nullptr; + + auto scope = groupMemcpyOp.getScope(); + if (scope == 2) { + unsigned workgroupSize[3] = {}; + auto shaderModes = m_pipelineState->getShaderModes(); + if (stage == ShaderStageTask) { + Module &module = *groupMemcpyOp.getModule(); + workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX; + workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY; + workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ; + } else if (stage == ShaderStageMesh) { + workgroupSize[0] = shaderModes->getMeshShaderMode().workgroupSizeX; + workgroupSize[1] = shaderModes->getMeshShaderMode().workgroupSizeY; + workgroupSize[2] = shaderModes->getMeshShaderMode().workgroupSizeZ; + } else { + llvm_unreachable("Invalid shade stage!"); + } + + scopeSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2]; + threadIndex = m_waveThreadInfo.threadIdInSubgroup; + } else { + llvm_unreachable("Unsupported scope!"); + } + + Patch::commonProcessGroupMemcpy(groupMemcpyOp, m_builder, threadIndex, scopeSize); + + m_callsToRemove.push_back(&groupMemcpyOp); +} + // ===================================================================================================================== // Lower task payload pointer to buffer fat pointer. // @@ -2422,7 +2461,7 @@ Value *MeshTaskShader::getMeshLocalInvocationId() { // The local invocation ID is packed to VGPR0 on GFX11+ with the following layout: // // +-----------------------+-----------------------+-----------------------+ - // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z | + // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X | // | [29:20] | [19:10] | [9:0] | // +-----------------------+-----------------------+-----------------------+ auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh; diff --git a/lgc/patch/MeshTaskShader.h b/lgc/patch/MeshTaskShader.h index 788ee5dfd0..b6a4a18613 100644 --- a/lgc/patch/MeshTaskShader.h +++ b/lgc/patch/MeshTaskShader.h @@ -80,6 +80,7 @@ class MeshTaskShader { void processTaskShader(llvm::Function *entryPoint); void processMeshShader(llvm::Function *entryPoint); + void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp); void lowerTaskPayloadPtr(TaskPayloadPtrOp &taskPayloadPtrOp); void lowerEmitMeshTasks(EmitMeshTasksOp &emitMeshTasksOp); void lowerSetMeshOutputs(SetMeshOutputsOp &setMeshOutputsOp); diff --git a/lgc/patch/Patch.cpp b/lgc/patch/Patch.cpp index 338c1f92ae..c622f43eba 100644 --- a/lgc/patch/Patch.cpp +++ b/lgc/patch/Patch.cpp @@ -100,6 +100,7 @@ #include "llvm/Transforms/Scalar/SimplifyCFG.h" #include "llvm/Transforms/Scalar/SpeculativeExecution.h" #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Mem2Reg.h" #define DEBUG_TYPE "lgc-patch" @@ -459,4 +460,68 @@ GlobalVariable *Patch::getLdsVariable(PipelineState *pipelineState, Module *modu return lds; } +// ===================================================================================================================== +// Common code to be shared to implement GroupMemcpyOp for MeshTaskShader and PatchEntryPointMutate. +void Patch::commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder, llvm::Value *threadIndex, + unsigned scopeSize) { + auto dst = groupMemcpyOp.getDst(); + auto src = groupMemcpyOp.getSrc(); + auto len = groupMemcpyOp.getSize(); + + // Copy in 16-bytes if possible + unsigned wideDwords = 4; + // If either pointer is in LDS, copy in 8-bytes + if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL || + dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL) + wideDwords = 2; + + unsigned baseOffset = 0; + + auto copyFunc = [&](Type *copyTy, unsigned copySize) { + Value *offset = + builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize))); + Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset); + Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset); + Value *data = builder.CreateLoad(copyTy, srcPtr); + builder.CreateStore(data, dstPtr); + }; + + unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords; + Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords); + while (baseOffset + wideDwordsCopySize * scopeSize <= len) { + copyFunc(wideDwordsTy, wideDwordsCopySize); + baseOffset += wideDwordsCopySize * scopeSize; + } + + unsigned dwordCopySize = sizeof(unsigned); + Type *dwordTy = builder.getInt32Ty(); + while (baseOffset + dwordCopySize * scopeSize <= len) { + copyFunc(dwordTy, dwordCopySize); + baseOffset += dwordCopySize * scopeSize; + } + + unsigned remainingBytes = len - baseOffset; + + if (remainingBytes) { + assert(remainingBytes % 4 == 0); + BasicBlock *afterBlock = groupMemcpyOp.getParent(); + BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr); + beforeBlock->takeName(afterBlock); + afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail"); + + // Split to create a tail copy block, empty except for an unconditional branch to afterBlock. + BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail"); + // Change the branch at the end of beforeBlock to be conditional. + beforeBlock->getTerminator()->eraseFromParent(); + builder.SetInsertPoint(beforeBlock); + + Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4)); + + builder.CreateCondBr(indexInRange, copyBlock, afterBlock); + // Create the copy instructions. + builder.SetInsertPoint(copyBlock->getTerminator()); + copyFunc(dwordTy, dwordCopySize); + } +} + } // namespace lgc diff --git a/lgc/patch/PatchEntryPointMutate.cpp b/lgc/patch/PatchEntryPointMutate.cpp index f5140664ca..4a8e90e9f5 100644 --- a/lgc/patch/PatchEntryPointMutate.cpp +++ b/lgc/patch/PatchEntryPointMutate.cpp @@ -57,7 +57,6 @@ #include "ShaderMerger.h" #include "lgc/LgcContext.h" #include "lgc/LgcDialect.h" -#include "lgc/builder/BuilderImpl.h" #include "lgc/patch/ShaderInputs.h" #include "lgc/state/AbiMetadata.h" #include "lgc/state/AbiUnlinked.h" @@ -75,7 +74,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include #define DEBUG_TYPE "lgc-patch-entry-point-mutate" @@ -180,7 +178,9 @@ bool PatchEntryPointMutate::runImpl(Module &module, PipelineShadersResult &pipel m_cpsShaderInputCache.clear(); - processGroupMemcpy(module); + if (!m_pipelineState->isGraphics()) + processCsGroupMemcpy(module); + return true; } @@ -261,10 +261,10 @@ static Value *mergeDwordsIntoVector(IRBuilder<> &builder, ArrayRef inpu // ===================================================================================================================== // Lower GroupMemcpyOp -void PatchEntryPointMutate::processGroupMemcpy(Module &module) { +void PatchEntryPointMutate::processCsGroupMemcpy(Module &module) { SmallVector toBeErased; struct Payload { - SmallVectorImpl &tobeErased; + SmallVectorImpl &toBeErased; PatchEntryPointMutate *self; }; Payload payload = {toBeErased, this}; @@ -272,188 +272,77 @@ void PatchEntryPointMutate::processGroupMemcpy(Module &module) { static auto visitor = llvm_dialects::VisitorBuilder() .setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration) .add([](auto &payload, auto &op) { - payload.self->lowerGroupMemcpy(op); - payload.tobeErased.push_back(&op); + payload.self->lowerCsGroupMemcpy(op); + payload.toBeErased.push_back(&op); }) .build(); visitor.visit(payload, module); - for (auto call : payload.tobeErased) + for (auto call : payload.toBeErased) call->eraseFromParent(); } // ===================================================================================================================== // Lower GroupMemcpyOp - Copy memory using threads in a workgroup (scope=2) or subgroup (scope=3). -void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) { - BuilderImpl builder(m_pipelineState); +void PatchEntryPointMutate::lowerCsGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) { + BuilderBase builder(groupMemcpyOp.getContext()); Function *entryPoint = groupMemcpyOp.getFunction(); - ShaderStage stage = getShaderStage(entryPoint); - builder.setShaderStage(stage); builder.SetInsertPoint(&groupMemcpyOp); - auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion(); - - auto dst = groupMemcpyOp.getDst(); - auto src = groupMemcpyOp.getSrc(); - auto len = groupMemcpyOp.getSize(); - auto scope = groupMemcpyOp.getScope(); - unsigned scopeSize = 0; Value *threadIndex = nullptr; + auto scope = groupMemcpyOp.getScope(); if (scope == 2) { unsigned workgroupSize[3] = {}; auto shaderModes = m_pipelineState->getShaderModes(); - if (stage == ShaderStageTask || stage == ShaderStageCompute) { - Module &module = *groupMemcpyOp.getModule(); - workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX; - workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY; - workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ; - } else if (stage == ShaderStageMesh) { - workgroupSize[0] = shaderModes->getMeshShaderMode().workgroupSizeX; - workgroupSize[1] = shaderModes->getMeshShaderMode().workgroupSizeY; - workgroupSize[2] = shaderModes->getMeshShaderMode().workgroupSizeZ; - } else { - llvm_unreachable("Invalid shade stage!"); - } - - // LocalInvocationId is a function argument now and CreateReadBuiltInInput cannot retrieve it. - unsigned argIndex = 0xFFFFFFFF; - switch (stage) { - case ShaderStageTask: { - auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.task; - argIndex = entryArgIdxs.localInvocationId; - break; - } - case ShaderStageMesh: { - auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh; - argIndex = entryArgIdxs.localInvocationId; - break; - } - case ShaderStageCompute: { - auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageCompute)->entryArgIdxs.cs; - argIndex = entryArgIdxs.localInvocationId; - break; - } - default: - llvm_unreachable("Invalid shade stage!"); - break; - } + assert(getShaderStage(entryPoint) == ShaderStageCompute); - const unsigned waveSize = m_pipelineState->getShaderWaveSize(stage); - - // For mesh shader the following two ids are required. - Value *waveIdInSubgroupMesh = nullptr; - Value *threadIdInWaveMesh = nullptr; - if (stage == ShaderStageMesh) { - builder.CreateIntrinsic(Intrinsic::amdgcn_init_exec, {}, builder.getInt64(-1)); - // waveId = mergedWaveInfo[27:24] - Value *mergedWaveInfo = - getFunctionArgument(entryPoint, ShaderMerger::getSpecialSgprInputIndex(gfxIp, EsGs::MergedWaveInfo)); - waveIdInSubgroupMesh = builder.CreateAnd(builder.CreateLShr(mergedWaveInfo, 24), 0xF, "waveIdInSubgroupMesh"); - - threadIdInWaveMesh = - builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {builder.getInt32(-1), builder.getInt32(0)}); - if (waveSize == 64) { - threadIdInWaveMesh = - builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {}, {builder.getInt32(-1), threadIdInWaveMesh}); - } - threadIdInWaveMesh->setName("threadIdInWaveMesh"); - } + Module &module = *groupMemcpyOp.getModule(); + workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX; + workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY; + workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ; - unsigned workgroupTotalSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2]; + scopeSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2]; - scopeSize = workgroupTotalSize; + auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageCompute)->entryArgIdxs.cs; + Value *threadIdInGroup = getFunctionArgument(entryPoint, entryArgIdxs.localInvocationId); + Value *threadIdComp[3]; - // localInvocationId argument for mesh shader is available from GFX11+. But it can be retrieved in anther way. - if (stage == ShaderStageMesh) { - threadIndex = builder.CreateAdd(builder.CreateMul(waveIdInSubgroupMesh, builder.getInt32(waveSize)), - threadIdInWaveMesh, "threadIdInSubgroupMesh"); + auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion(); + if (gfxIp.major < 11) { + for (unsigned idx = 0; idx < 3; idx++) + threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx); } else { - Value *threadIdInGroup = getFunctionArgument(entryPoint, argIndex); - Value *threadIdComp[3]; - if (gfxIp.major < 11) { - for (unsigned idx = 0; idx < 3; idx++) - threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx); - } else { - // The local invocation ID is packed to VGPR0 on GFX11+ with the following layout: - // - // +-----------------------+-----------------------+-----------------------+ - // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z | - // | [29:20] | [19:10] | [9:0] | - // +-----------------------+-----------------------+-----------------------+ - // localInvocationIdZ = localInvocationId[29:20] - threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ"); - // localInvocationIdY = localInvocationId[19:10] - threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY"); - // localInvocationIdX = localInvocationId[9:0] - threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX"); - } + // The local invocation ID is packed to VGPR0 on GFX11+ with the following layout: + // + // +-----------------------+-----------------------+-----------------------+ + // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X | + // | [29:20] | [19:10] | [9:0] | + // +-----------------------+-----------------------+-----------------------+ + // localInvocationIdZ = localInvocationId[29:20] + threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ"); + // localInvocationIdY = localInvocationId[19:10] + threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY"); + // localInvocationIdX = localInvocationId[9:0] + threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX"); + } - // LocalInvocationIndex is - // (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X + // LocalInvocationIndex is + // (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X + // tidigCompCnt is not always set to 2(xyz) if groupSizeY and/or groupSizeZ are 1. See RegisterMetadataBuilder.cpp. + threadIndex = builder.getInt32(0); + if (workgroupSize[2] > 1) threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1])); + if (workgroupSize[1] > 1) { threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]); threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0])); - threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]); } + threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]); } else { llvm_unreachable("Unsupported scope!"); } - // Copy in 16-bytes if possible - unsigned wideDwords = 4; - // If either pointer is in LDS, copy in 8-bytes - if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL || - dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL) - wideDwords = 2; - - unsigned baseOffset = 0; - - auto copyFunc = [&](Type *copyTy, unsigned copySize) { - Value *offset = - builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize))); - Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset); - Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset); - Value *data = builder.CreateLoad(copyTy, srcPtr); - builder.CreateStore(data, dstPtr); - }; - - unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords; - Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords); - while (baseOffset + wideDwordsCopySize * scopeSize <= len) { - copyFunc(wideDwordsTy, wideDwordsCopySize); - baseOffset += wideDwordsCopySize * scopeSize; - } - - unsigned dwordCopySize = sizeof(unsigned); - Type *dwordTy = builder.getInt32Ty(); - while (baseOffset + dwordCopySize * scopeSize <= len) { - copyFunc(dwordTy, dwordCopySize); - baseOffset += dwordCopySize * scopeSize; - } - - unsigned remainingBytes = len - baseOffset; - - if (remainingBytes) { - assert(remainingBytes % 4 == 0); - BasicBlock *afterBlock = groupMemcpyOp.getParent(); - BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr); - beforeBlock->takeName(afterBlock); - afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail"); - - // Split to create a tail copy block, empty except for an unconditional branch to afterBlock. - BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail"); - // Change the branch at the end of beforeBlock to be conditional. - beforeBlock->getTerminator()->eraseFromParent(); - builder.SetInsertPoint(beforeBlock); - - Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4)); - - builder.CreateCondBr(indexInRange, copyBlock, afterBlock); - // Create the copy instructions. - builder.SetInsertPoint(copyBlock->getTerminator()); - copyFunc(dwordTy, dwordCopySize); - } + commonProcessGroupMemcpy(groupMemcpyOp, builder, threadIndex, scopeSize); } // =====================================================================================================================