Skip to content

Commit

Permalink
lgc: Move handling of GroupMemcpy for mesh/task shaders to MeshTaskSh…
Browse files Browse the repository at this point in the history
…ader
  • Loading branch information
xazhangAMD committed Nov 10, 2023
1 parent 82ef966 commit bd8a42b
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 159 deletions.
5 changes: 5 additions & 0 deletions lgc/include/lgc/patch/Patch.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class PassBuilder;

namespace lgc {

class BuilderBase;
class GroupMemcpyOp;
class PipelineState;
class PassManager;

Expand All @@ -63,6 +65,9 @@ class Patch {

static llvm::GlobalVariable *getLdsVariable(PipelineState *pipelineState, llvm::Module *module);

static void commonProcessGroupMemcpy(GroupMemcpyOp &groupMemcpyOp, lgc::BuilderBase& builder,
llvm::Value *threadIndex, unsigned scopeSize);

protected:
static void addOptimizationPasses(lgc::PassManager &passMgr, uint32_t optLevel);

Expand Down
4 changes: 2 additions & 2 deletions lgc/include/lgc/patch/PatchEntryPointMutate.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin<PatchEntr

bool isComputeWithCalls() const;

void processGroupMemcpy(llvm::Module &module);
void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void processCsGroupMemcpy(llvm::Module &module);
void lowerCsGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);

bool m_hasTs; // Whether the pipeline has tessellation shader
bool m_hasGs; // Whether the pipeline has geometry shader
Expand Down
41 changes: 40 additions & 1 deletion lgc/patch/MeshTaskShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ void MeshTaskShader::processTaskShader(Function *entryPoint) {

static auto visitor = llvm_dialects::VisitorBuilder<MeshTaskShader>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
.add<GroupMemcpyOp>(&MeshTaskShader::lowerGroupMemcpy)
.add<TaskPayloadPtrOp>(&MeshTaskShader::lowerTaskPayloadPtr)
.add<EmitMeshTasksOp>(&MeshTaskShader::lowerEmitMeshTasks)
.build();
Expand Down Expand Up @@ -919,6 +920,44 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
}
}

// =====================================================================================================================
// Lower GroupMemcpyOp for task/mesh shaders - copy memory using all threads in a workgroup.
//
// Only workgroup scope (scope == 2) is handled. The workgroup dimensions are read from the compute shader mode for a
// task shader and from the mesh shader mode for a mesh shader; the actual copy sequence is emitted by
// Patch::commonProcessGroupMemcpy. The op itself is not erased here, it is queued in m_callsToRemove.
//
// @param groupMemcpyOp : The group memcpy op to lower
void MeshTaskShader::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
  const ShaderStage shaderStage = getShaderStage(groupMemcpyOp.getFunction());
  m_builder.SetInsertPoint(&groupMemcpyOp);

  // Anything other than workgroup scope is rejected up front.
  if (groupMemcpyOp.getScope() != 2)
    llvm_unreachable("Unsupported scope!");

  // Pick the workgroup dimensions from the shader mode that belongs to this stage.
  unsigned dimX = 0;
  unsigned dimY = 0;
  unsigned dimZ = 0;
  auto modes = m_pipelineState->getShaderModes();
  switch (shaderStage) {
  case ShaderStageTask: {
    Module &parentModule = *groupMemcpyOp.getModule();
    dimX = modes->getComputeShaderMode(parentModule).workgroupSizeX;
    dimY = modes->getComputeShaderMode(parentModule).workgroupSizeY;
    dimZ = modes->getComputeShaderMode(parentModule).workgroupSizeZ;
    break;
  }
  case ShaderStageMesh:
    dimX = modes->getMeshShaderMode().workgroupSizeX;
    dimY = modes->getMeshShaderMode().workgroupSizeY;
    dimZ = modes->getMeshShaderMode().workgroupSizeZ;
    break;
  default:
    llvm_unreachable("Invalid shade stage!");
  }

  // Number of threads taking part in the copy is the flattened workgroup size; each thread is identified by the
  // cached thread ID in subgroup from m_waveThreadInfo.
  const unsigned flatWorkgroupSize = dimX * dimY * dimZ;
  Patch::commonProcessGroupMemcpy(groupMemcpyOp, m_builder, m_waveThreadInfo.threadIdInSubgroup, flatWorkgroupSize);

  // Defer removal of the lowered op to the owner of m_callsToRemove.
  m_callsToRemove.push_back(&groupMemcpyOp);
}

// =====================================================================================================================
// Lower task payload pointer to buffer fat pointer.
//
Expand Down Expand Up @@ -2422,7 +2461,7 @@ Value *MeshTaskShader::getMeshLocalInvocationId() {
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z |
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh;
Expand Down
1 change: 1 addition & 0 deletions lgc/patch/MeshTaskShader.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class MeshTaskShader {
void processTaskShader(llvm::Function *entryPoint);
void processMeshShader(llvm::Function *entryPoint);

void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);
void lowerTaskPayloadPtr(TaskPayloadPtrOp &taskPayloadPtrOp);
void lowerEmitMeshTasks(EmitMeshTasksOp &emitMeshTasksOp);
void lowerSetMeshOutputs(SetMeshOutputsOp &setMeshOutputsOp);
Expand Down
201 changes: 45 additions & 156 deletions lgc/patch/PatchEntryPointMutate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
#include "ShaderMerger.h"
#include "lgc/LgcContext.h"
#include "lgc/LgcDialect.h"
#include "lgc/builder/BuilderImpl.h"
#include "lgc/patch/ShaderInputs.h"
#include "lgc/state/AbiMetadata.h"
#include "lgc/state/AbiUnlinked.h"
Expand All @@ -75,7 +74,6 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <optional>

#define DEBUG_TYPE "lgc-patch-entry-point-mutate"
Expand Down Expand Up @@ -180,7 +178,9 @@ bool PatchEntryPointMutate::runImpl(Module &module, PipelineShadersResult &pipel

m_cpsShaderInputCache.clear();

processGroupMemcpy(module);
if (!m_pipelineState->isGraphics())
processCsGroupMemcpy(module);

return true;
}

Expand Down Expand Up @@ -261,199 +261,88 @@ static Value *mergeDwordsIntoVector(IRBuilder<> &builder, ArrayRef<Value *> inpu

// =====================================================================================================================
// Lower GroupMemcpyOp
void PatchEntryPointMutate::processGroupMemcpy(Module &module) {
void PatchEntryPointMutate::processCsGroupMemcpy(Module &module) {
SmallVector<CallInst *> toBeErased;
struct Payload {
SmallVectorImpl<CallInst *> &tobeErased;
SmallVectorImpl<CallInst *> &toBeErased;
PatchEntryPointMutate *self;
};
Payload payload = {toBeErased, this};

static auto visitor = llvm_dialects::VisitorBuilder<Payload>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
.add<GroupMemcpyOp>([](auto &payload, auto &op) {
payload.self->lowerGroupMemcpy(op);
payload.tobeErased.push_back(&op);
payload.self->lowerCsGroupMemcpy(op);
payload.toBeErased.push_back(&op);
})
.build();
visitor.visit(payload, module);
for (auto call : payload.tobeErased)
for (auto call : payload.toBeErased)
call->eraseFromParent();
}

// =====================================================================================================================
// Lower GroupMemcpyOp - Copy memory using threads in a workgroup (scope=2). Other scopes, including subgroup
// (scope=3), are currently rejected as unsupported.
void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
BuilderImpl builder(m_pipelineState);
void PatchEntryPointMutate::lowerCsGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
BuilderBase builder(groupMemcpyOp.getContext());
Function *entryPoint = groupMemcpyOp.getFunction();
ShaderStage stage = getShaderStage(entryPoint);
builder.setShaderStage(stage);
builder.SetInsertPoint(&groupMemcpyOp);

auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion();

auto dst = groupMemcpyOp.getDst();
auto src = groupMemcpyOp.getSrc();
auto len = groupMemcpyOp.getSize();
auto scope = groupMemcpyOp.getScope();

unsigned scopeSize = 0;
Value *threadIndex = nullptr;

auto scope = groupMemcpyOp.getScope();
if (scope == 2) {
unsigned workgroupSize[3] = {};
auto shaderModes = m_pipelineState->getShaderModes();
if (stage == ShaderStageTask || stage == ShaderStageCompute) {
Module &module = *groupMemcpyOp.getModule();
workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX;
workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY;
workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ;
} else if (stage == ShaderStageMesh) {
workgroupSize[0] = shaderModes->getMeshShaderMode().workgroupSizeX;
workgroupSize[1] = shaderModes->getMeshShaderMode().workgroupSizeY;
workgroupSize[2] = shaderModes->getMeshShaderMode().workgroupSizeZ;
} else {
llvm_unreachable("Invalid shade stage!");
}

// LocalInvocationId is a function argument now and CreateReadBuiltInInput cannot retrieve it.
unsigned argIndex = 0xFFFFFFFF;
switch (stage) {
case ShaderStageTask: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.task;
argIndex = entryArgIdxs.localInvocationId;
break;
}
case ShaderStageMesh: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh;
argIndex = entryArgIdxs.localInvocationId;
break;
}
case ShaderStageCompute: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageCompute)->entryArgIdxs.cs;
argIndex = entryArgIdxs.localInvocationId;
break;
}
default:
llvm_unreachable("Invalid shade stage!");
break;
}
assert(getShaderStage(entryPoint) == ShaderStageCompute);

const unsigned waveSize = m_pipelineState->getShaderWaveSize(stage);

// For mesh shader the following two ids are required.
Value *waveIdInSubgroupMesh = nullptr;
Value *threadIdInWaveMesh = nullptr;
if (stage == ShaderStageMesh) {
builder.CreateIntrinsic(Intrinsic::amdgcn_init_exec, {}, builder.getInt64(-1));
// waveId = mergedWaveInfo[27:24]
Value *mergedWaveInfo =
getFunctionArgument(entryPoint, ShaderMerger::getSpecialSgprInputIndex(gfxIp, EsGs::MergedWaveInfo));
waveIdInSubgroupMesh = builder.CreateAnd(builder.CreateLShr(mergedWaveInfo, 24), 0xF, "waveIdInSubgroupMesh");

threadIdInWaveMesh =
builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {builder.getInt32(-1), builder.getInt32(0)});
if (waveSize == 64) {
threadIdInWaveMesh =
builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {}, {builder.getInt32(-1), threadIdInWaveMesh});
}
threadIdInWaveMesh->setName("threadIdInWaveMesh");
}
Module &module = *groupMemcpyOp.getModule();
workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX;
workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY;
workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ;

unsigned workgroupTotalSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
scopeSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];

scopeSize = workgroupTotalSize;
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageCompute)->entryArgIdxs.cs;
Value *threadIdInGroup = getFunctionArgument(entryPoint, entryArgIdxs.localInvocationId);
Value *threadIdComp[3];

// localInvocationId argument for mesh shader is available from GFX11+. But it can be retrieved in anther way.
if (stage == ShaderStageMesh) {
threadIndex = builder.CreateAdd(builder.CreateMul(waveIdInSubgroupMesh, builder.getInt32(waveSize)),
threadIdInWaveMesh, "threadIdInSubgroupMesh");
auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion();
if (gfxIp.major < 11) {
for (unsigned idx = 0; idx < 3; idx++)
threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx);
} else {
Value *threadIdInGroup = getFunctionArgument(entryPoint, argIndex);
Value *threadIdComp[3];
if (gfxIp.major < 11) {
for (unsigned idx = 0; idx < 3; idx++)
threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx);
} else {
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
// localInvocationIdZ = localInvocationId[29:20]
threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ");
// localInvocationIdY = localInvocationId[19:10]
threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY");
// localInvocationIdX = localInvocationId[9:0]
threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX");
}
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
// localInvocationIdZ = localInvocationId[29:20]
threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ");
// localInvocationIdY = localInvocationId[19:10]
threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY");
// localInvocationIdX = localInvocationId[9:0]
threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX");
}

// LocalInvocationIndex is
// (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X
// LocalInvocationIndex is
// (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X
// tidigCompCnt is not always set to 2(xyz) if groupSizeY and/or groupSizeZ are 1. See RegisterMetadataBuilder.cpp.
threadIndex = builder.getInt32(0);
if (workgroupSize[2] > 1)
threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1]));
if (workgroupSize[1] > 1) {
threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]);
threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0]));
threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]);
}
threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]);
} else {
llvm_unreachable("Unsupported scope!");
}

// Copy in 16-bytes if possible
unsigned wideDwords = 4;
// If either pointer is in LDS, copy in 8-bytes
if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL ||
dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL)
wideDwords = 2;

unsigned baseOffset = 0;

auto copyFunc = [&](Type *copyTy, unsigned copySize) {
Value *offset =
builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize)));
Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset);
Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset);
Value *data = builder.CreateLoad(copyTy, srcPtr);
builder.CreateStore(data, dstPtr);
};

unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords;
Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords);
while (baseOffset + wideDwordsCopySize * scopeSize <= len) {
copyFunc(wideDwordsTy, wideDwordsCopySize);
baseOffset += wideDwordsCopySize * scopeSize;
}

unsigned dwordCopySize = sizeof(unsigned);
Type *dwordTy = builder.getInt32Ty();
while (baseOffset + dwordCopySize * scopeSize <= len) {
copyFunc(dwordTy, dwordCopySize);
baseOffset += dwordCopySize * scopeSize;
}

unsigned remainingBytes = len - baseOffset;

if (remainingBytes) {
assert(remainingBytes % 4 == 0);
BasicBlock *afterBlock = groupMemcpyOp.getParent();
BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr);
beforeBlock->takeName(afterBlock);
afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail");

// Split to create a tail copy block, empty except for an unconditional branch to afterBlock.
BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail");
// Change the branch at the end of beforeBlock to be conditional.
beforeBlock->getTerminator()->eraseFromParent();
builder.SetInsertPoint(beforeBlock);

Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4));

builder.CreateCondBr(indexInRange, copyBlock, afterBlock);
// Create the copy instructions.
builder.SetInsertPoint(copyBlock->getTerminator());
copyFunc(dwordTy, dwordCopySize);
}
commonProcessGroupMemcpy(groupMemcpyOp, builder, threadIndex, scopeSize);
}

// =====================================================================================================================
Expand Down

0 comments on commit bd8a42b

Please sign in to comment.