Skip to content

Commit

Permalink
lgc: Move handling of GroupMemcpy for mesh/task shaders to MeshTaskSh…
Browse files Browse the repository at this point in the history
…ader
  • Loading branch information
xazhangAMD committed Nov 11, 2023
1 parent 4320b9c commit 98f6e63
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 167 deletions.
5 changes: 5 additions & 0 deletions lgc/include/lgc/patch/Patch.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class PassBuilder;

namespace lgc {

class BuilderBase;
class GroupMemcpyOp;
class PipelineState;
class PassManager;

Expand All @@ -63,6 +65,9 @@ class Patch {

static llvm::GlobalVariable *getLdsVariable(PipelineState *pipelineState, llvm::Module *module);

// Shared lowering of a GroupMemcpyOp: emits, through `builder`, the per-thread loads/stores that copy the op's source
// to its destination. `threadIndex` selects which chunk the executing thread copies and `scopeSize` is the number of
// threads taking part in the copy (must be non-zero).
static void commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder,
llvm::Value *threadIndex, unsigned scopeSize);

protected:
static void addOptimizationPasses(lgc::PassManager &passMgr, uint32_t optLevel);

Expand Down
4 changes: 2 additions & 2 deletions lgc/include/lgc/patch/PatchEntryPointMutate.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin<PatchEntr

bool isComputeWithCalls() const;

void processGroupMemcpy(llvm::Module &module);
void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void processCsGroupMemcpy(llvm::Module &module);
void lowerCsGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);

bool m_hasTs; // Whether the pipeline has tessellation shader
bool m_hasGs; // Whether the pipeline has geometry shader
Expand Down
41 changes: 40 additions & 1 deletion lgc/patch/MeshTaskShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ void MeshTaskShader::processTaskShader(Function *entryPoint) {

static auto visitor = llvm_dialects::VisitorBuilder<MeshTaskShader>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
.add<GroupMemcpyOp>(&MeshTaskShader::lowerGroupMemcpy)
.add<TaskPayloadPtrOp>(&MeshTaskShader::lowerTaskPayloadPtr)
.add<EmitMeshTasksOp>(&MeshTaskShader::lowerEmitMeshTasks)
.build();
Expand Down Expand Up @@ -919,6 +920,44 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
}
}

// =====================================================================================================================
// Lower GroupMemcpyOp - copy memory using all threads in a workgroup.
//
// Only scope value 2 is handled (presumably the workgroup scope, cf. SPIR-V Scope::Workgroup == 2 - TODO confirm), and
// only for task and mesh shader entry points. The copy instructions themselves are emitted by
// Patch::commonProcessGroupMemcpy; the op is queued in m_callsToRemove instead of being erased here.
//
// @param groupMemcpyOp : Call instruction op to do group memory copy
void MeshTaskShader::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
  Function *entryPoint = groupMemcpyOp.getFunction();
  ShaderStage stage = getShaderStage(entryPoint);
  m_builder.SetInsertPoint(&groupMemcpyOp);

  unsigned scopeSize = 0;
  Value *threadIndex = nullptr;

  auto scope = groupMemcpyOp.getScope();
  if (scope == 2) {
    unsigned workgroupSize[3] = {};
    auto shaderModes = m_pipelineState->getShaderModes();
    if (stage == ShaderStageTask) {
      // Task shaders take their workgroup dimensions from the compute shader mode; look it up once.
      Module &module = *groupMemcpyOp.getModule();
      const auto &computeMode = shaderModes->getComputeShaderMode(module);
      workgroupSize[0] = computeMode.workgroupSizeX;
      workgroupSize[1] = computeMode.workgroupSizeY;
      workgroupSize[2] = computeMode.workgroupSizeZ;
    } else if (stage == ShaderStageMesh) {
      const auto &meshMode = shaderModes->getMeshShaderMode();
      workgroupSize[0] = meshMode.workgroupSizeX;
      workgroupSize[1] = meshMode.workgroupSizeY;
      workgroupSize[2] = meshMode.workgroupSizeZ;
    } else {
      llvm_unreachable("Invalid shader stage!");
    }

    // Every thread of the workgroup takes part in the copy.
    scopeSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
    threadIndex = m_waveThreadInfo.threadIdInSubgroup;
  } else {
    llvm_unreachable("Unsupported scope!");
  }

  Patch::commonProcessGroupMemcpy(groupMemcpyOp, m_builder, threadIndex, scopeSize);

  // Defer the removal of the lowered op.
  m_callsToRemove.push_back(&groupMemcpyOp);
}

// =====================================================================================================================
// Lower task payload pointer to buffer fat pointer.
//
Expand Down Expand Up @@ -2422,7 +2461,7 @@ Value *MeshTaskShader::getMeshLocalInvocationId() {
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z |
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh;
Expand Down
1 change: 1 addition & 0 deletions lgc/patch/MeshTaskShader.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class MeshTaskShader {
void processTaskShader(llvm::Function *entryPoint);
void processMeshShader(llvm::Function *entryPoint);

void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void lowerTaskPayloadPtr(TaskPayloadPtrOp &taskPayloadPtrOp);
void lowerEmitMeshTasks(EmitMeshTasksOp &emitMeshTasksOp);
void lowerSetMeshOutputs(SetMeshOutputsOp &setMeshOutputsOp);
Expand Down
65 changes: 65 additions & 0 deletions lgc/patch/Patch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include "llvm/Transforms/Scalar/SpeculativeExecution.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Mem2Reg.h"

#define DEBUG_TYPE "lgc-patch"
Expand Down Expand Up @@ -459,4 +460,68 @@ GlobalVariable *Patch::getLdsVariable(PipelineState *pipelineState, Module *modu
return lds;
}

// =====================================================================================================================
// Common code to be shared to implement GroupMemcpyOp for MeshTaskShader and PatchEntryPointMutate.
//
// The copy is distributed over the threads of the group: in one round, the thread with index i moves one chunk located
// at byte offset (baseOffset + i * chunkSize), so a full round moves (chunkSize * scopeSize) bytes. Full rounds are
// emitted first with wide chunks (16 bytes, or 8 bytes if either pointer is in ADDR_SPACE_LOCAL), then with single
// dwords. Whatever is left (less than one full dword round) is copied in a new basic block that is only entered by
// threads with threadIndex < remainingBytes / 4.
//
// Notes for callers:
//   - Instructions are emitted at the builder's current insert point, which is expected to be at groupMemcpyOp.
//   - When a tail copy is required, the block containing groupMemcpyOp is split into three blocks and the builder is
//     left positioned before the terminator of the new tail-copy block.
//   - groupMemcpyOp itself is not erased by this function.
//
// @param groupMemcpyOp : Call instruction op to do group memory copy
// @param builder : Builder used to emit the copy instructions
// @param threadIndex : Index of the executing thread among the threads taking part in the copy
// @param scopeSize : Number of threads taking part in the copy; must be non-zero
void Patch::commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder, llvm::Value *threadIndex,
                                     unsigned scopeSize) {
  // With a zero-sized scope the round loops below would never advance baseOffset and this function would not return.
  assert(scopeSize > 0 && "GroupMemcpy requires a non-zero scope size");

  auto dst = groupMemcpyOp.getDst();
  auto src = groupMemcpyOp.getSrc();
  auto len = groupMemcpyOp.getSize();

  // Copy in 16-bytes if possible
  unsigned wideDwords = 4;
  // If either pointer is in LDS, copy in 8-bytes
  if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL ||
      dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL)
    wideDwords = 2;

  // Number of bytes already covered by the rounds emitted so far.
  unsigned baseOffset = 0;

  // Emit one load/store pair that copies this thread's chunk of the current round. baseOffset is captured by
  // reference so that each call uses the offset of the round being emitted.
  auto copyFunc = [&](Type *copyTy, unsigned copySize) {
    Value *offset =
        builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize)));
    Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset);
    Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset);
    Value *data = builder.CreateLoad(copyTy, srcPtr);
    builder.CreateStore(data, dstPtr);
  };

  // Full rounds with wide chunks: every thread copies wideDwords dwords per round.
  unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords;
  Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords);
  while (baseOffset + wideDwordsCopySize * scopeSize <= len) {
    copyFunc(wideDwordsTy, wideDwordsCopySize);
    baseOffset += wideDwordsCopySize * scopeSize;
  }

  // Full rounds with single dwords: every thread copies one dword per round.
  unsigned dwordCopySize = sizeof(unsigned);
  Type *dwordTy = builder.getInt32Ty();
  while (baseOffset + dwordCopySize * scopeSize <= len) {
    copyFunc(dwordTy, dwordCopySize);
    baseOffset += dwordCopySize * scopeSize;
  }

  // Less than one full dword round is left at this point.
  unsigned remainingBytes = len - baseOffset;

  if (remainingBytes) {
    // The tail is copied in whole dwords, one per participating thread.
    assert(remainingBytes % 4 == 0);
    BasicBlock *afterBlock = groupMemcpyOp.getParent();
    // Move everything in front of the op into a new predecessor block, which keeps the original block name.
    BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr);
    beforeBlock->takeName(afterBlock);
    afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail");

    // Split to create a tail copy block, empty except for an unconditional branch to afterBlock.
    BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail");
    // Change the branch at the end of beforeBlock to be conditional.
    beforeBlock->getTerminator()->eraseFromParent();
    builder.SetInsertPoint(beforeBlock);

    // Only the first (remainingBytes / 4) threads have a dword left to copy.
    Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4));

    builder.CreateCondBr(indexInRange, copyBlock, afterBlock);
    // Create the copy instructions.
    builder.SetInsertPoint(copyBlock->getTerminator());
    copyFunc(dwordTy, dwordCopySize);
  }
}

} // namespace lgc
Loading

0 comments on commit 98f6e63

Please sign in to comment.