Skip to content

Commit

Permalink
lgc: add dialect GroupMemcpyOp (#2802)
Browse files Browse the repository at this point in the history
  • Loading branch information
xazhangAMD authored Nov 8, 2023
1 parent 071f063 commit 8e7a79b
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lgc/include/lgc/patch/PatchEntryPointMutate.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#pragma once

#include "lgc/LgcCpsDialect.h"
#include "lgc/LgcDialect.h"
#include "lgc/patch/Patch.h"
#include "lgc/patch/ShaderInputs.h"
#include "lgc/state/PipelineShaders.h"
Expand Down Expand Up @@ -161,6 +162,9 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin<PatchEntr

bool isComputeWithCalls() const;

void processGroupMemcpy(llvm::Module &module);
void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);

bool m_hasTs; // Whether the pipeline has tessllation shader
bool m_hasGs; // Whether the pipeline has geometry shader
PipelineState *m_pipelineState = nullptr; // Pipeline state from PipelineStateWrapper pass
Expand Down
19 changes: 19 additions & 0 deletions lgc/interface/lgc/LgcDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,22 @@ def UserDataOp : LgcOp<"user.data", [Memory<[]>, WillReturn]> {
This operation is used for push constants in Vulkan and in some cases by OpenGL.
}];
}

def GroupMemcpyOp : LgcOp<"group.memcpy", [Memory<[]>]> {
let arguments = (ins PointerType:$dst, PointerType:$src, AttrI32:$size, AttrI32:$scope);
let results = (outs);

let summary = "copy a memory area cooperatively using the threads of a group";
let description = [{
Only usable in compute-like shader types (compute, task, mesh).

Use all threads of a group (workgroup or subgroup aka wave) to copy `size` bytes
from `src` to `dst`.

`dst`, `src`, and `size` must be uniform at the given scope.

`scope` is 2 for workgroup scope and 3 for subgroup (wave) scope. No other values are allowed.

This operation must only occur in control flow that is uniform for the relevant scope.
}];
}
202 changes: 202 additions & 0 deletions lgc/patch/PatchEntryPointMutate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
*/

#include "lgc/patch/PatchEntryPointMutate.h"
#include "ShaderMerger.h"
#include "lgc/LgcContext.h"
#include "lgc/LgcDialect.h"
#include "lgc/builder/BuilderImpl.h"
#include "lgc/patch/ShaderInputs.h"
#include "lgc/state/AbiMetadata.h"
#include "lgc/state/AbiUnlinked.h"
Expand All @@ -73,6 +75,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <optional>

#define DEBUG_TYPE "lgc-patch-entry-point-mutate"
Expand Down Expand Up @@ -176,6 +179,8 @@ bool PatchEntryPointMutate::runImpl(Module &module, PipelineShadersResult &pipel
shaderInputs.fixupUses(*m_module, m_pipelineState, isComputeWithCalls());

m_cpsShaderInputCache.clear();

processGroupMemcpy(module);
return true;
}

Expand Down Expand Up @@ -254,6 +259,203 @@ static Value *mergeDwordsIntoVector(IRBuilder<> &builder, ArrayRef<Value *> inpu
return vec;
}

// =====================================================================================================================
// Lower GroupMemcpyOp
void PatchEntryPointMutate::processGroupMemcpy(Module &module) {
SmallVector<CallInst *> tobeErased;
struct Payload {
SmallVectorImpl<CallInst *> &tobeErased;
PatchEntryPointMutate *self;
};
Payload payload = {tobeErased, this};

static auto visitor = llvm_dialects::VisitorBuilder<Payload>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
.add<GroupMemcpyOp>([](auto &payload, auto &op) {
payload.self->lowerGroupMemcpy(op);
payload.tobeErased.push_back(&op);
})
.build();
visitor.visit(payload, module);
for (auto call : payload.tobeErased)
call->eraseFromParent();
}

// =====================================================================================================================
// Lower GroupMemcpyOp - Copy memory using threads in a workgroup (scope=2) or subgroup (scope=3).
void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
BuilderImpl builder(m_pipelineState);
Function *entryPoint = groupMemcpyOp.getFunction();
ShaderStage stage = getShaderStage(entryPoint);
builder.setShaderStage(stage);
builder.SetInsertPoint(&groupMemcpyOp);

auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion();

auto dst = groupMemcpyOp.getDst();
auto src = groupMemcpyOp.getSrc();
auto len = groupMemcpyOp.getSize();
auto scope = groupMemcpyOp.getScope();

unsigned scopeSize = 0;
Value *threadIndex = nullptr;

if (scope == 2) {
unsigned workgroupSize[3] = {};
auto shaderModes = m_pipelineState->getShaderModes();
if (stage == ShaderStageTask || stage == ShaderStageCompute) {
Module &module = *groupMemcpyOp.getModule();
workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX;
workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY;
workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ;
} else if (stage == ShaderStageMesh) {
workgroupSize[0] = shaderModes->getMeshShaderMode().workgroupSizeX;
workgroupSize[1] = shaderModes->getMeshShaderMode().workgroupSizeY;
workgroupSize[2] = shaderModes->getMeshShaderMode().workgroupSizeZ;
} else {
llvm_unreachable("Invalid shade stage!");
}

// LocalInvocationId is a function argument now and CreateReadBuiltInInput cannot retrieve it.
unsigned argIndex = 0xFFFFFFFF;
switch (stage) {
case ShaderStageTask: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.task;
argIndex = entryArgIdxs.localInvocationId;
break;
}
case ShaderStageMesh: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.mesh;
argIndex = entryArgIdxs.localInvocationId;
break;
}
case ShaderStageCompute: {
auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.cs;
argIndex = entryArgIdxs.localInvocationId;
break;
}
default:
llvm_unreachable("Invalid shade stage!");
break;
}

const unsigned waveSize = m_pipelineState->getShaderWaveSize(stage);

// For mesh shader the following two ids are required.
Value *waveIdInSubgroupMesh = nullptr;
Value *threadIdInWaveMesh = nullptr;
if (stage == ShaderStageMesh) {
builder.CreateIntrinsic(Intrinsic::amdgcn_init_exec, {}, builder.getInt64(-1));
// waveId = mergedWaveInfo[27:24]
Value *mergedWaveInfo =
getFunctionArgument(entryPoint, ShaderMerger::getSpecialSgprInputIndex(gfxIp, EsGs::MergedWaveInfo));
waveIdInSubgroupMesh = builder.CreateAnd(builder.CreateLShr(mergedWaveInfo, 24), 0xF, "waveIdInSubgroupMesh");

threadIdInWaveMesh =
builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {builder.getInt32(-1), builder.getInt32(0)});
if (waveSize == 64) {
threadIdInWaveMesh =
builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {}, {builder.getInt32(-1), threadIdInWaveMesh});
}
threadIdInWaveMesh->setName("threadIdInWaveMesh");
}

unsigned workgroupTotalSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];

scopeSize = workgroupTotalSize;

// localInvocationId argument for mesh shader is available from GFX11+. But it can be retrieved in anther way.
if (stage == ShaderStageMesh) {
threadIndex = builder.CreateAdd(builder.CreateMul(waveIdInSubgroupMesh, builder.getInt32(waveSize)),
threadIdInWaveMesh, "threadIdInSubgroupMesh");
} else {
Value *threadIdInGroup = getFunctionArgument(entryPoint, argIndex);
Value *threadIdComp[3];
if (gfxIp.major < 11) {
for (unsigned idx = 0; idx < 3; idx++)
threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx);
} else {
// The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
//
// +-----------------------+-----------------------+-----------------------+
// | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z |
// | [29:20] | [19:10] | [9:0] |
// +-----------------------+-----------------------+-----------------------+
// localInvocationIdZ = localInvocationId[29:20]
threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ");
// localInvocationIdY = localInvocationId[19:10]
threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY");
// localInvocationIdX = localInvocationId[9:0]
threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX");
}

// LocalInvocationIndex is
// (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X
threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1]));
threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]);
threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0]));
threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]);
}
} else {
llvm_unreachable("Unsupported scope!");
}

// Copy in 16-bytes if possible
unsigned wideDwords = 4;
// If either pointer is in LDS, copy in 8-bytes
if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL ||
dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL)
wideDwords = 2;

unsigned baseOffset = 0;

auto copyFunc = [&](Type *copyTy, unsigned copySize) {
Value *offset =
builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize)));
Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset);
Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset);
Value *data = builder.CreateLoad(copyTy, srcPtr);
builder.CreateStore(data, dstPtr);
};

unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords;
Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords);
while (baseOffset + wideDwordsCopySize * scopeSize <= len) {
copyFunc(wideDwordsTy, wideDwordsCopySize);
baseOffset += wideDwordsCopySize * scopeSize;
}

unsigned dwordCopySize = sizeof(unsigned);
Type *dwordTy = builder.getInt32Ty();
while (baseOffset + dwordCopySize * scopeSize <= len) {
copyFunc(dwordTy, dwordCopySize);
baseOffset += dwordCopySize * scopeSize;
}

unsigned remainingBytes = len - baseOffset;

if (remainingBytes) {
assert(remainingBytes % 4 == 0);
BasicBlock *afterBlock = groupMemcpyOp.getParent();
BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr);
beforeBlock->takeName(afterBlock);
afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail");

// Split to create a tail copy block, empty except for an unconditional branch to afterBlock.
BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail");
// Change the branch at the end of beforeBlock to be conditional.
beforeBlock->getTerminator()->eraseFromParent();
builder.SetInsertPoint(beforeBlock);

Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4));

builder.CreateCondBr(indexInRange, copyBlock, afterBlock);
// Create the copy instructions.
builder.SetInsertPoint(copyBlock->getTerminator());
copyFunc(dwordTy, dwordCopySize);
}
}

// =====================================================================================================================
// Lower as.continuation.reference call.
//
Expand Down

0 comments on commit 8e7a79b

Please sign in to comment.