Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lgc: add dialect GroupMemcpyOp #2802

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lgc/include/lgc/patch/PatchEntryPointMutate.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#pragma once

#include "lgc/LgcCpsDialect.h"
#include "lgc/LgcDialect.h"
#include "lgc/patch/Patch.h"
#include "lgc/patch/ShaderInputs.h"
#include "lgc/state/PipelineShaders.h"
Expand Down Expand Up @@ -161,6 +162,9 @@ class PatchEntryPointMutate : public Patch, public llvm::PassInfoMixin<PatchEntr

bool isComputeWithCalls() const;

void processGroupMemcpy(llvm::Module &module);
void lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp);

bool m_hasTs; // Whether the pipeline has tessllation shader
bool m_hasGs; // Whether the pipeline has geometry shader
PipelineState *m_pipelineState = nullptr; // Pipeline state from PipelineStateWrapper pass
Expand Down
19 changes: 19 additions & 0 deletions lgc/interface/lgc/LgcDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,22 @@ def UserDataOp : LgcOp<"user.data", [Memory<[]>, WillReturn]> {
This operation is used for push constants in Vulkan and in some cases by OpenGL.
}];
}

// Cooperative group memcpy: all threads of the given scope participate in one copy.
// Memory<[]> : the op's memory effects are unknown/unmodeled to generic analyses.
def GroupMemcpyOp : LgcOp<"group.memcpy", [Memory<[]>]> {
  // $dst/$src : destination/source pointers (must be scope-uniform)
  // $size     : number of bytes to copy (i32 attribute, scope-uniform)
  // $scope    : 2 = workgroup, 3 = subgroup (wave); no other values allowed
  let arguments = (ins PointerType:$dst, PointerType:$src, AttrI32:$size, AttrI32:$scope);
  let results = (outs);

  let summary = "copy a memory area cooperatively using the threads of a group";
  let description = [{
    Only usable in compute-like shader types (compute, task, mesh).

    Use all threads of a group (workgroup or subgroup aka wave) to copy `size` bytes
    from `src` to `dst`.

    `dst`, `src`, and `size` must be uniform at the given scope.

    `scope` is 2 for workgroup scope and 3 for subgroup (wave) scope. No other values are allowed.

    This operation must only occur in control flow that is uniform for the relevant scope.
  }];
}
202 changes: 202 additions & 0 deletions lgc/patch/PatchEntryPointMutate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
*/

#include "lgc/patch/PatchEntryPointMutate.h"
#include "ShaderMerger.h"
#include "lgc/LgcContext.h"
#include "lgc/LgcDialect.h"
#include "lgc/builder/BuilderImpl.h"
#include "lgc/patch/ShaderInputs.h"
#include "lgc/state/AbiMetadata.h"
#include "lgc/state/AbiUnlinked.h"
Expand All @@ -73,6 +75,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <optional>

#define DEBUG_TYPE "lgc-patch-entry-point-mutate"
Expand Down Expand Up @@ -176,6 +179,8 @@ bool PatchEntryPointMutate::runImpl(Module &module, PipelineShadersResult &pipel
shaderInputs.fixupUses(*m_module, m_pipelineState, isComputeWithCalls());

m_cpsShaderInputCache.clear();

processGroupMemcpy(module);
return true;
}

Expand Down Expand Up @@ -254,6 +259,203 @@ static Value *mergeDwordsIntoVector(IRBuilder<> &builder, ArrayRef<Value *> inpu
return vec;
}

// =====================================================================================================================
// Lower all lgc.group.memcpy operations in the module.
//
// Visits every GroupMemcpyOp (by function declaration, so modules without the op are skipped cheaply),
// lowers each one in place via lowerGroupMemcpy(), and erases the original calls afterwards.
// Erasure is deferred so the visitor never iterates over instructions it is deleting.
//
// @param module : LLVM module to process
void PatchEntryPointMutate::processGroupMemcpy(Module &module) {
  SmallVector<CallInst *> tobeErased;
  struct Payload {
    SmallVectorImpl<CallInst *> &tobeErased;
    PatchEntryPointMutate *self;
  };
  Payload payload = {tobeErased, this};

  // The visitor is stateless and expensive to build, so construct it once and reuse it across runs.
  static auto visitor = llvm_dialects::VisitorBuilder<Payload>()
                            .setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
                            .add<GroupMemcpyOp>([](auto &payload, auto &op) {
                              payload.self->lowerGroupMemcpy(op);
                              payload.tobeErased.push_back(&op);
                            })
                            .build();
  visitor.visit(payload, module);

  // Erase the lowered ops only after visitation is complete.
  for (auto call : payload.tobeErased)
    call->eraseFromParent();
}

// =====================================================================================================================
// Lower GroupMemcpyOp - Copy memory using threads in a workgroup (scope=2) or subgroup (scope=3).
//
// Emits a strided cooperative copy: each thread copies chunks at (baseOffset + threadIndex * chunkSize),
// with the whole group advancing baseOffset by chunkSize * scopeSize per iteration. The bulk is copied in
// wide (16-byte, or 8-byte when LDS is involved) chunks, then dwords, then a guarded tail for the remainder.
// The op itself is NOT erased here; the caller (processGroupMemcpy) erases it.
//
// @param groupMemcpyOp : the lgc.group.memcpy operation to lower
void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) {
  BuilderImpl builder(m_pipelineState);
  Function *entryPoint = groupMemcpyOp.getFunction();
  ShaderStage stage = getShaderStage(entryPoint);
  builder.setShaderStage(stage);
  builder.SetInsertPoint(&groupMemcpyOp);

  auto gfxIp = m_pipelineState->getTargetInfo().getGfxIpVersion();

  auto dst = groupMemcpyOp.getDst();
  auto src = groupMemcpyOp.getSrc();
  auto len = groupMemcpyOp.getSize();
  auto scope = groupMemcpyOp.getScope();

  unsigned scopeSize = 0;
  Value *threadIndex = nullptr;

  if (scope == 2) {
    // Workgroup scope: compute the flat thread index within the workgroup and the workgroup's total size.
    unsigned workgroupSize[3] = {};
    auto shaderModes = m_pipelineState->getShaderModes();
    if (stage == ShaderStageTask || stage == ShaderStageCompute) {
      Module &module = *groupMemcpyOp.getModule();
      workgroupSize[0] = shaderModes->getComputeShaderMode(module).workgroupSizeX;
      workgroupSize[1] = shaderModes->getComputeShaderMode(module).workgroupSizeY;
      workgroupSize[2] = shaderModes->getComputeShaderMode(module).workgroupSizeZ;
    } else if (stage == ShaderStageMesh) {
      workgroupSize[0] = shaderModes->getMeshShaderMode().workgroupSizeX;
      workgroupSize[1] = shaderModes->getMeshShaderMode().workgroupSizeY;
      workgroupSize[2] = shaderModes->getMeshShaderMode().workgroupSizeZ;
    } else {
      llvm_unreachable("Invalid shader stage!");
    }

    // LocalInvocationId is a function argument now and CreateReadBuiltInInput cannot retrieve it.
    unsigned argIndex = 0xFFFFFFFF;
    switch (stage) {
    case ShaderStageTask: {
      auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageTask)->entryArgIdxs.task;
      argIndex = entryArgIdxs.localInvocationId;
      break;
    }
    case ShaderStageMesh: {
      // Query the mesh shader's own interface data (not the task shader's).
      auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageMesh)->entryArgIdxs.mesh;
      argIndex = entryArgIdxs.localInvocationId;
      break;
    }
    case ShaderStageCompute: {
      // Query the compute shader's own interface data (not the task shader's).
      auto &entryArgIdxs = m_pipelineState->getShaderInterfaceData(ShaderStageCompute)->entryArgIdxs.cs;
      argIndex = entryArgIdxs.localInvocationId;
      break;
    }
    default:
      llvm_unreachable("Invalid shader stage!");
      break;
    }

    const unsigned waveSize = m_pipelineState->getShaderWaveSize(stage);

    // For mesh shader the following two ids are required.
    Value *waveIdInSubgroupMesh = nullptr;
    Value *threadIdInWaveMesh = nullptr;
    if (stage == ShaderStageMesh) {
      builder.CreateIntrinsic(Intrinsic::amdgcn_init_exec, {}, builder.getInt64(-1));
      // waveId = mergedWaveInfo[27:24]
      Value *mergedWaveInfo =
          getFunctionArgument(entryPoint, ShaderMerger::getSpecialSgprInputIndex(gfxIp, EsGs::MergedWaveInfo));
      waveIdInSubgroupMesh = builder.CreateAnd(builder.CreateLShr(mergedWaveInfo, 24), 0xF, "waveIdInSubgroupMesh");

      threadIdInWaveMesh =
          builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {builder.getInt32(-1), builder.getInt32(0)});
      if (waveSize == 64) {
        threadIdInWaveMesh =
            builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_hi, {}, {builder.getInt32(-1), threadIdInWaveMesh});
      }
      threadIdInWaveMesh->setName("threadIdInWaveMesh");
    }

    unsigned workgroupTotalSize = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];

    scopeSize = workgroupTotalSize;

    // localInvocationId argument for mesh shader is available from GFX11+. But it can be retrieved in another way.
    if (stage == ShaderStageMesh) {
      threadIndex = builder.CreateAdd(builder.CreateMul(waveIdInSubgroupMesh, builder.getInt32(waveSize)),
                                      threadIdInWaveMesh, "threadIdInSubgroupMesh");
    } else {
      Value *threadIdInGroup = getFunctionArgument(entryPoint, argIndex);
      Value *threadIdComp[3];
      if (gfxIp.major < 11) {
        for (unsigned idx = 0; idx < 3; idx++)
          threadIdComp[idx] = builder.CreateExtractElement(threadIdInGroup, idx);
      } else {
        // The local invocation ID is packed to VGPR0 on GFX11+ with the following layout:
        //
        //   +-----------------------+-----------------------+-----------------------+
        //   | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X |
        //   | [29:20]               | [19:10]               | [9:0]                 |
        //   +-----------------------+-----------------------+-----------------------+
        // localInvocationIdZ = localInvocationId[29:20]
        threadIdComp[2] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 20), 0x3FF, "localInvocationIdZ");
        // localInvocationIdY = localInvocationId[19:10]
        threadIdComp[1] = builder.CreateAnd(builder.CreateLShr(threadIdInGroup, 10), 0x3FF, "localInvocationIdY");
        // localInvocationIdX = localInvocationId[9:0]
        threadIdComp[0] = builder.CreateAnd(threadIdInGroup, 0x3FF, "localInvocationIdX");
      }

      // LocalInvocationIndex is
      // (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X
      threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1]));
      threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]);
      threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0]));
      threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]);
    }
  } else {
    // NOTE(review): scope == 3 (subgroup/wave) is declared by the dialect op but not implemented here yet.
    llvm_unreachable("Unsupported scope!");
  }

  // Copy in 16-bytes if possible
  unsigned wideDwords = 4;
  // If either pointer is in LDS, copy in 8-bytes
  if (src->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL ||
      dst->getType()->getPointerAddressSpace() == ADDR_SPACE_LOCAL)
    wideDwords = 2;

  unsigned baseOffset = 0;

  // Emit one per-thread chunk copy of copySize bytes at (baseOffset + threadIndex * copySize).
  auto copyFunc = [&](Type *copyTy, unsigned copySize) {
    Value *offset =
        builder.CreateAdd(builder.getInt32(baseOffset), builder.CreateMul(threadIndex, builder.getInt32(copySize)));
    Value *dstPtr = builder.CreateGEP(builder.getInt8Ty(), dst, offset);
    Value *srcPtr = builder.CreateGEP(builder.getInt8Ty(), src, offset);
    Value *data = builder.CreateLoad(copyTy, srcPtr);
    builder.CreateStore(data, dstPtr);
  };

  // Bulk phase 1: full-group strides of wide chunks (every thread participates).
  unsigned wideDwordsCopySize = sizeof(unsigned) * wideDwords;
  Type *wideDwordsTy = ArrayType::get(builder.getInt32Ty(), wideDwords);
  while (baseOffset + wideDwordsCopySize * scopeSize <= len) {
    copyFunc(wideDwordsTy, wideDwordsCopySize);
    baseOffset += wideDwordsCopySize * scopeSize;
  }

  // Bulk phase 2: full-group strides of single dwords.
  unsigned dwordCopySize = sizeof(unsigned);
  Type *dwordTy = builder.getInt32Ty();
  while (baseOffset + dwordCopySize * scopeSize <= len) {
    copyFunc(dwordTy, dwordCopySize);
    baseOffset += dwordCopySize * scopeSize;
  }

  // Tail: fewer than scopeSize dwords remain, so only threads with index < remainingBytes/4 copy.
  unsigned remainingBytes = len - baseOffset;

  if (remainingBytes) {
    assert(remainingBytes % 4 == 0);
    BasicBlock *afterBlock = groupMemcpyOp.getParent();
    BasicBlock *beforeBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr);
    beforeBlock->takeName(afterBlock);
    afterBlock->setName(Twine(beforeBlock->getName()) + ".afterGroupMemcpyTail");

    // Split to create a tail copy block, empty except for an unconditional branch to afterBlock.
    BasicBlock *copyBlock = splitBlockBefore(afterBlock, &groupMemcpyOp, nullptr, nullptr, nullptr, ".groupMemcpyTail");
    // Change the branch at the end of beforeBlock to be conditional.
    beforeBlock->getTerminator()->eraseFromParent();
    builder.SetInsertPoint(beforeBlock);

    Value *indexInRange = builder.CreateICmpULT(threadIndex, builder.getInt32(remainingBytes / 4));

    builder.CreateCondBr(indexInRange, copyBlock, afterBlock);
    // Create the copy instructions.
    builder.SetInsertPoint(copyBlock->getTerminator());
    copyFunc(dwordTy, dwordCopySize);
  }
}

// =====================================================================================================================
// Lower as.continuation.reference call.
//
Expand Down