diff --git a/lgc/patch/PatchEntryPointMutate.cpp b/lgc/patch/PatchEntryPointMutate.cpp index f5140664ca..96b5899e72 100644 --- a/lgc/patch/PatchEntryPointMutate.cpp +++ b/lgc/patch/PatchEntryPointMutate.cpp @@ -378,7 +378,7 @@ void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) { // The local invocation ID is packed to VGPR0 on GFX11+ with the following layout: // // +-----------------------+-----------------------+-----------------------+ - // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID Z | + // | Local Invocation ID Z | Local Invocation ID Y | Local Invocation ID X | // | [29:20] | [19:10] | [9:0] | // +-----------------------+-----------------------+-----------------------+ // localInvocationIdZ = localInvocationId[29:20] @@ -391,9 +391,14 @@ void PatchEntryPointMutate::lowerGroupMemcpy(GroupMemcpyOp &groupMemcpyOp) { // LocalInvocationIndex is // (LocalInvocationId.Z * WorkgroupSize.Y + LocalInvocationId.Y) * WorkGroupSize.X + LocalInvocationId.X - threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1])); - threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]); - threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0])); + // tidigCompCnt is not always 3 if groupSizeY and/or groupSizeZ are 1. See RegisterMetadataBuilder.cpp. + // Only the Z/Y ID reads are conditional; the WorkgroupSize.X scale must always apply (e.g. Y == 1 but Z > 1). + threadIndex = builder.getInt32(0); + if (workgroupSize[2] > 1) + threadIndex = builder.CreateMul(threadIdComp[2], builder.getInt32(workgroupSize[1])); + if (workgroupSize[1] > 1) + threadIndex = builder.CreateAdd(threadIndex, threadIdComp[1]); + threadIndex = builder.CreateMul(threadIndex, builder.getInt32(workgroupSize[0])); threadIndex = builder.CreateAdd(threadIndex, threadIdComp[0]); } } else {