Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lgc: Move handling of GroupMemcpy for mesh/task shaders to MeshTaskShader #2814

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lgc/include/lgc/patch/Patch.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class PassBuilder;

namespace lgc {

class BuilderBase;
class GroupMemcpyOp;
class PipelineState;
class PassManager;

Expand All @@ -63,6 +65,9 @@ class Patch {

static llvm::GlobalVariable *getLdsVariable(PipelineState *pipelineState, llvm::Module *module);

static void commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder,
llvm::Value *threadIndex, unsigned scopeSize);

protected:
static void addOptimizationPasses(lgc::PassManager &passMgr, uint32_t optLevel);

Expand Down
4 changes: 2 additions & 2 deletions lgc/include/lgc/patch/PatchEntryPointMutate.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin<PatchEntr

bool isComputeWithCalls() const;

void processGroupMemcpy(llvm::Module &module);
void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void processCsGroupMemcpy(llvm::Module &module);
void lowerCsGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);

bool m_hasTs; // Whether the pipeline has tessellation shader
bool m_hasGs; // Whether the pipeline has geometry shader
Expand Down
41 changes: 40 additions & 1 deletion lgc/patch/MeshTaskShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ void MeshTaskShader::processTaskShader(Function *entryPoint) {

static auto visitor = llvm_dialects::VisitorBuilder<MeshTaskShader>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
.add<GroupMemcpyOp>(&MeshTaskShader::lowerGroupMemcpy)
.add<TaskPayloadPtrOp>(&MeshTaskShader::lowerTaskPayloadPtr)
.add<EmitMeshTasksOp>(&MeshTaskShader::lowerEmitMeshTasks)
.build();
Expand Down Expand Up @@ -919,6 +920,44 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
}
}

// =====================================================================================================================
// Lowers a GroupMemcpyOp found in a task or mesh shader. The number of participating threads is taken from the
// workgroup dimensions of the shader stage, and the actual copy code is emitted by Patch::commonProcessGroupMemcpy.
//
// @param groupMemcpyOp : The group memcpy operation to lower
void MeshTaskShader::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
  // All new instructions are emitted right in front of the operation being lowered.
  m_builder.SetInsertPoint(&groupMemcpyOp);

  // Scope value 2 is the only one handled here; it is sized from the workgroup dimensions below.
  if (groupMemcpyOp.getScope() != 2)
    llvm_unreachable("Unsupported scope!");

  const ShaderStage shaderStage = getShaderStage(groupMemcpyOp.getFunction());
  auto modes = m_pipelineState->getShaderModes();

  unsigned dimX = 0;
  unsigned dimY = 0;
  unsigned dimZ = 0;
  if (shaderStage == ShaderStageTask) {
    // Task shaders read their workgroup dimensions from the compute shader mode.
    Module &module = *groupMemcpyOp.getModule();
    const auto &taskMode = modes->getComputeShaderMode(module);
    dimX = taskMode.workgroupSizeX;
    dimY = taskMode.workgroupSizeY;
    dimZ = taskMode.workgroupSizeZ;
  } else if (shaderStage == ShaderStageMesh) {
    const auto &meshMode = modes->getMeshShaderMode();
    dimX = meshMode.workgroupSizeX;
    dimY = meshMode.workgroupSizeY;
    dimZ = meshMode.workgroupSizeZ;
  } else {
    llvm_unreachable("Invalid shade stage!");
  }

  // Total thread count of the workgroup, and the per-thread index used to address each thread's slice of the copy.
  const unsigned threadCount = dimX * dimY * dimZ;
  Value *const threadId = m_waveThreadInfo.threadIdInSubgroup;

  Patch::commonProcessGroupMemcpy(groupMemcpyOp, m_builder, threadId, threadCount);

  // The op itself is not erased here; it is only queued in m_callsToRemove.
  m_callsToRemove.push_back(&groupMemcpyOp);
}

// =====================================================================================================================
// Lower task payload pointer to buffer fat pointer.
//
Expand Down Expand Up @@ -2422,7 +2461,7 @@ Value *MeshTaskShader::getMeshLocalInvocationId() {
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z |
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh;
Expand Down
1 change: 1 addition & 0 deletions lgc/patch/MeshTaskShader.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class MeshTaskShader {
void processTaskShader(llvm::Function *entryPoint);
void processMeshShader(llvm::Function *entryPoint);

void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void lowerTaskPayloadPtr(TaskPayloadPtrOp &taskPayloadPtrOp);
void lowerEmitMeshTasks(EmitMeshTasksOp &emitMeshTasksOp);
void lowerSetMeshOutputs(SetMeshOutputsOp &setMeshOutputsOp);
Expand Down
65 changes: 65 additions & 0 deletions lgc/patch/Patch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include "llvm/Transforms/Scalar/SpeculativeExecution.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Mem2Reg.h"

#define DEBUG_TYPE "lgc-patch"
Expand Down Expand Up @@ -459,4 +460,68 @@ GlobalVariable *Patch::getLdsVariable(PipelineState *pipelineState, Module *modu
return lds;
}

// =====================================================================================================================
// Common code shared by MeshTaskShader and PatchEntryPointMutate to implement GroupMemcpyOp.
//
// The copy is split across the threads of the scope: in every step, the thread with index i copies one element located
// at byte offset (baseOffset + i * elementSize), so that one step moves (scopeSize * elementSize) bytes in total. Wide
// elements (4 dwords, or 2 dwords when either pointer is in LDS) are used while a full step still fits, then single
// dwords, and finally a tail that is executed only by the first (remainingBytes / 4) threads under a branch.
//
// Instructions for the full steps are emitted at the builder's current insert point, which is not set by this
// function. When a tail is needed, the block containing the op is split in front of the op and the builder is left
// positioned inside the newly created tail-copy block.
//
// @param groupMemcpyOp : The group memcpy operation to lower (supplies destination, source and byte size)
// @param builder : Builder used to emit the copy; its insert point must already be set by the caller
// @param threadIndex : Index of the current thread within the scope (i32 value)
// @param scopeSize : Number of threads taking part in the copy; must be non-zero
void Patch::commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase &builder, llvm::Value *threadIndex,
                                     unsigned scopeSize) {
  // With an empty scope a step would copy zero bytes, so baseOffset would never advance and the loops below would
  // never terminate.
  assert(scopeSize > 0 && "GroupMemcpyOp requires a non-empty scope");

  auto dst = groupMemcpyOp.getDst();
  auto src = groupMemcpyOp.getSrc();
  auto len = groupMemcpyOp.getSize();

  // Copy in 16-bytes if possible
  unsigned wideDwords = 4;
  // If either pointer is in LDS, copy in 8-bytes
  if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL ||
      dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL)
    wideDwords = 2;

  // Number of bytes already covered by the steps emitted so far. Captured by reference in copyFunc, so each call uses
  // the value current at the time of the call.
  unsigned baseOffset = 0;

  // Emits one per-thread element copy: load a value of copyTy from src and store it to dst, both at byte offset
  // (baseOffset + threadIndex * copySize).
  auto copyFunc = [&](Type *copyTy, unsigned copySize) {
    Value *threadOffset = builder.CreateMul(threadIndex, builder.getInt32(copySize));
    Value *offset = builder.CreateAdd(builder.getInt32(baseOffset), threadOffset);
    Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset);
    Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset);
    Value *data = builder.CreateLoad(copyTy, srcPtr);
    builder.CreateStore(data, dstPtr);
  };

  // Full steps using the wide element type.
  unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords;
  Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords);
  while (baseOffset + wideDwordsCopySize * scopeSize <= len) {
    copyFunc(wideDwordsTy, wideDwordsCopySize);
    baseOffset += wideDwordsCopySize * scopeSize;
  }

  // Full steps using single dwords.
  unsigned dwordCopySize = sizeof(unsigned);
  Type *dwordTy = builder.getInt32Ty();
  while (baseOffset + dwordCopySize * scopeSize <= len) {
    copyFunc(dwordTy, dwordCopySize);
    baseOffset += dwordCopySize * scopeSize;
  }

  unsigned remainingBytes = len - baseOffset;

  if (remainingBytes) {
    // The tail is copied dword-wise, so the total size has to be a multiple of a dword.
    assert(remainingBytes % dwordCopySize == 0);

    // Split in front of the op: everything emitted so far stays in beforeBlock, which keeps the original block name;
    // the op and all following instructions end up in afterBlock.
    BasicBlock *afterBlock = groupMemcpyOp.getParent();
    BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr);
    beforeBlock->takeName(afterBlock);
    afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail");

    // Split to create a tail copy block, empty except for an unconditional branch to afterBlock.
    BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail");
    // Change the branch at the end of beforeBlock to be conditional.
    beforeBlock->getTerminator()->eraseFromParent();
    builder.SetInsertPoint(beforeBlock);

    // Only the first (remainingBytes / dword size) threads have a dword left to copy.
    Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / dwordCopySize));

    builder.CreateCondBr(indexInRange, copyBlock, afterBlock);
    // Create the copy instructions.
    builder.SetInsertPoint(copyBlock->getTerminator());
    copyFunc(dwordTy, dwordCopySize);
  }
}

} // namespace lgc
Loading
Loading