Skip to content

Commit

Permalink
For indirect mesh shaders fetch up-to-date dispatch arguments
Browse files Browse the repository at this point in the history
* This doesn't work for whole-pass mesh fetch, and will cause
  inconsistencies between single draws due to non-determinism, but should avoid
  crashes.
  • Loading branch information
baldurk committed Dec 7, 2023
1 parent fd84a06 commit 4bccf2d
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 20 deletions.
10 changes: 7 additions & 3 deletions qrenderdoc/Windows/BufferViewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ struct BufferConfiguration
BufferData *indices = NULL;
int32_t baseVertex = 0;

rdcfixedarray<uint32_t, 3> dispatchSize;
rdcarray<TaskGroupSize> taskSizes;
rdcarray<uint32_t> meshletVertexPrefixCounts;
uint32_t taskOrMeshletOffset = 0;
Expand Down Expand Up @@ -562,6 +563,7 @@ struct BufferConfiguration

baseVertex = o.baseVertex;
meshletVertexPrefixCounts = o.meshletVertexPrefixCounts;
dispatchSize = o.dispatchSize;
taskSizes = o.taskSizes;
taskOrMeshletOffset = o.taskOrMeshletOffset;
perPrimitiveOffset = o.perPrimitiveOffset;
Expand Down Expand Up @@ -595,6 +597,7 @@ struct BufferConfiguration
b->deref();

meshletVertexPrefixCounts.clear();
dispatchSize = {};
taskSizes.clear();

buffers.clear();
Expand Down Expand Up @@ -1938,6 +1941,7 @@ static void RT_FetchMeshPipeData(IReplayController *r, ICaptureContext &ctx, Pop
data->out1Config.displayIndices->deref();
data->out1Config.displayIndices = NULL;

data->out1Config.dispatchSize = data->postOut1.dispatchSize;
data->out1Config.taskSizes = data->postOut1.taskSizes;

if(data->postOut1.vertexResourceId != ResourceId())
Expand Down Expand Up @@ -3753,11 +3757,11 @@ void BufferViewer::OnEventChanged(uint32_t eventId)
const ActionDescription *action = m_Ctx.CurAction();

uint32_t i = 0;
for(uint32_t x = 0; x < action->dispatchDimension[0]; x++)
for(uint32_t x = 0; x < bufdata->out1Config.dispatchSize[0]; x++)
{
for(uint32_t y = 0; y < action->dispatchDimension[1]; y++)
for(uint32_t y = 0; y < bufdata->out1Config.dispatchSize[1]; y++)
{
for(uint32_t z = 0; z < action->dispatchDimension[2]; z++)
for(uint32_t z = 0; z < bufdata->out1Config.dispatchSize[2]; z++)
{
TaskGroupSize size = bufdata->out1Config.taskSizes[i];

Expand Down
13 changes: 13 additions & 0 deletions renderdoc/api/replay/control_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ the number of vertices).
)");
rdcarray<MeshletSize> meshletSizes;

DOCUMENT(R"(The size of the dispatch that launched a meshlet based draw.
Only valid for the task stage if task shaders are used.
.. note::
This is present because the dispatch size at the time of the mesh output fetch may not match the
:data:`ActionDescription.dispatchDimension` due to non-determinism in the capture. Being present
here allows the replay to process the mesh output validly in itself without seeing a mismatch.
:type: Tuple[int,int,int]
)");
rdcfixedarray<uint32_t, 3> dispatchSize;

DOCUMENT(R"(The size of each task group's dispatch, for a meshlet based draw.
Each group of a task shader within a dispatch can itself fill out a payload and dispatch a number
Expand Down
22 changes: 22 additions & 0 deletions renderdoc/driver/d3d12/d3d12_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,28 @@ StencilOperation MakeStencilOp(D3D12_STENCIL_OP op)
return StencilOperation::Keep;
}

// Returns how many bytes one indirect argument of the given description takes up in an argument
// buffer. Callers can sum this over the preceding arguments of a command signature to locate a
// particular argument within a command. An unrecognised argument type is logged as an error and
// reported as 0 bytes.
uint32_t ArgumentTypeByteSize(const D3D12_INDIRECT_ARGUMENT_DESC &arg)
{
  uint32_t byteSize = 0;

  switch(arg.Type)
  {
    // draw/dispatch style arguments are fixed-size structs
    case D3D12_INDIRECT_ARGUMENT_TYPE_DRAW:
    {
      byteSize = sizeof(D3D12_DRAW_ARGUMENTS);
      break;
    }
    case D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED:
    {
      byteSize = sizeof(D3D12_DRAW_INDEXED_ARGUMENTS);
      break;
    }
    case D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH:
    {
      byteSize = sizeof(D3D12_DISPATCH_ARGUMENTS);
      break;
    }
    case D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH_MESH:
    {
      byteSize = sizeof(D3D12_DISPATCH_MESH_ARGUMENTS);
      break;
    }
    // root constants are variable-sized: one 32-bit value per constant being set
    case D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT:
    {
      byteSize = sizeof(uint32_t) * arg.Constant.Num32BitValuesToSet;
      break;
    }
    case D3D12_INDIRECT_ARGUMENT_TYPE_VERTEX_BUFFER_VIEW:
    {
      byteSize = sizeof(D3D12_VERTEX_BUFFER_VIEW);
      break;
    }
    case D3D12_INDIRECT_ARGUMENT_TYPE_INDEX_BUFFER_VIEW:
    {
      byteSize = sizeof(D3D12_INDEX_BUFFER_VIEW);
      break;
    }
    // root CBV/SRV/UAV arguments are each sized as a single GPU virtual address
    case D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT_BUFFER_VIEW:
    case D3D12_INDIRECT_ARGUMENT_TYPE_SHADER_RESOURCE_VIEW:
    case D3D12_INDIRECT_ARGUMENT_TYPE_UNORDERED_ACCESS_VIEW:
    {
      byteSize = sizeof(D3D12_GPU_VIRTUAL_ADDRESS);
      break;
    }
    default:
    {
      RDCERR("Unexpected argument type! %d", arg.Type);
      break;
    }
  }

  return byteSize;
}

UINT GetResourceNumMipLevels(const D3D12_RESOURCE_DESC *desc)
{
switch(desc->Dimension)
Expand Down
2 changes: 2 additions & 0 deletions renderdoc/driver/d3d12/d3d12_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ BlendMultiplier MakeBlendMultiplier(D3D12_BLEND blend, bool alpha);
BlendOperation MakeBlendOp(D3D12_BLEND_OP op);
StencilOperation MakeStencilOp(D3D12_STENCIL_OP op);

uint32_t ArgumentTypeByteSize(const D3D12_INDIRECT_ARGUMENT_DESC &arg);

// wrapper around D3D12_RESOURCE_STATES and D3D12_BARRIER_LAYOUT to handle resources that could be
// in either and varying support
struct D3D12ResourceLayout
Expand Down
70 changes: 58 additions & 12 deletions renderdoc/driver/d3d12/d3d12_postvs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,8 +2056,7 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

const ActionDescription *action = m_pDevice->GetAction(eventId);

uint32_t totalNumMeshlets =
action->dispatchDimension[0] * action->dispatchDimension[1] * action->dispatchDimension[2];
rdcfixedarray<uint32_t, 3> dispatchSize = action->dispatchDimension;

D3D12RenderState &rs = m_pDevice->GetQueue()->GetCommandData()->m_RenderState;

Expand All @@ -2067,6 +2066,50 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)
(WrappedID3D12PipelineState *)rm->GetCurrentAs<ID3D12PipelineState>(rs.pipe);
D3D12RootSignature modsig;

// for indirect dispatches, fetch up to date dispatch sizes in case they're non-deterministic
if(action->flags & ActionFlags::Indirect)
{
uint32_t chunkIdx = action->events.back().chunkIndex;
uint32_t parentIdx = action->parent->events.back().chunkIndex;
const SDFile *file = m_pDevice->GetStructuredFile();

if(chunkIdx < file->chunks.size() && parentIdx < file->chunks.size())
{
const SDChunk *chunk = file->chunks[chunkIdx];
const SDChunk *parentChunk = file->chunks[parentIdx];

uint32_t cmdIdx = chunk->FindChild("CommandIndex")->AsUInt32();
uint32_t argIdx = chunk->FindChild("ArgumentIndex")->AsUInt32();

WrappedID3D12CommandSignature *comSig = rm->GetLiveAs<WrappedID3D12CommandSignature>(
parentChunk->FindChild("pCommandSignature")->AsResourceId());
ID3D12Resource *argBuf =
rm->GetLiveAs<ID3D12Resource>(parentChunk->FindChild("pArgumentBuffer")->AsResourceId());
uint64_t argOffs = parentChunk->FindChild("ArgumentBufferOffset")->AsUInt64();

argOffs += cmdIdx * comSig->sig.ByteStride;

for(uint32_t i = 0; i < argIdx; i++)
argOffs += ArgumentTypeByteSize(comSig->sig.arguments[i]);

bytebuf dispatchArgs;
GetDebugManager()->GetBufferData(argBuf, argOffs, sizeof(D3D12_DISPATCH_MESH_ARGUMENTS),
dispatchArgs);

if(dispatchArgs.size() >= sizeof(D3D12_DISPATCH_MESH_ARGUMENTS))
{
D3D12_DISPATCH_MESH_ARGUMENTS *meshArgs =
(D3D12_DISPATCH_MESH_ARGUMENTS *)dispatchArgs.data();

dispatchSize[0] = meshArgs->ThreadGroupCountX;
dispatchSize[1] = meshArgs->ThreadGroupCountY;
dispatchSize[2] = meshArgs->ThreadGroupCountZ;
}
}
}

uint32_t totalNumMeshlets = dispatchSize[0] * dispatchSize[1] * dispatchSize[2];

// set defaults so that we don't try to fetch this output again if something goes wrong and the
// same event is selected again
{
Expand Down Expand Up @@ -2159,8 +2202,8 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

if(pipeDesc.AS.BytecodeLength > 0)
{
AddDXILAmpShaderPayloadStores(pipe->AS()->GetDXBC(), space, action->dispatchDimension,
payloadSize, ampFetchDXIL);
AddDXILAmpShaderPayloadStores(pipe->AS()->GetDXBC(), space, dispatchSize, payloadSize,
ampFetchDXIL);

// strip the root signature, we shouldn't need it and it may no longer match and fail validation
DXBC::DXBCContainer::StripChunk(ampFetchDXIL, DXBC::FOURCC_RTS0);
Expand Down Expand Up @@ -2256,8 +2299,7 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

rs.ApplyState(m_pDevice, list);

list->DispatchMesh(action->dispatchDimension[0], action->dispatchDimension[1],
action->dispatchDimension[2]);
list->DispatchMesh(dispatchSize[0], dispatchSize[1], dispatchSize[2]);

list->Close();

Expand Down Expand Up @@ -2315,8 +2357,7 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

GetDebugManager()->ResetDebugAlloc();

ConvertToFixedDXILAmpFeeder(pipe->AS()->GetDXBC(), space, action->dispatchDimension,
ampFeederDXIL);
ConvertToFixedDXILAmpFeeder(pipe->AS()->GetDXBC(), space, dispatchSize, ampFeederDXIL);

// strip the root signature, we shouldn't need it and it may no longer match and fail validation
DXBC::DXBCContainer::StripChunk(ampFeederDXIL, DXBC::FOURCC_RTS0);
Expand All @@ -2329,8 +2370,8 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

bytebuf meshOutputDXIL;

AddDXILMeshShaderOutputStores(pipe->MS()->GetDXBC(), space, ampBuffer != NULL,
action->dispatchDimension, layout, meshOutputDXIL);
AddDXILMeshShaderOutputStores(pipe->MS()->GetDXBC(), space, ampBuffer != NULL, dispatchSize,
layout, meshOutputDXIL);

{
// strip the root signature, we shouldn't need it and it may no longer match and fail validation
Expand Down Expand Up @@ -2447,8 +2488,7 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

rs.ApplyState(m_pDevice, list);

list->DispatchMesh(action->dispatchDimension[0], action->dispatchDimension[1],
action->dispatchDimension[2]);
list->DispatchMesh(dispatchSize[0], dispatchSize[1], dispatchSize[2]);

list->Close();

Expand Down Expand Up @@ -2695,6 +2735,8 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)

ret.ampout.hasPosOut = false;

ret.ampout.dispatchSize = dispatchSize;

ret.meshout.buf = meshBuffer;

ret.meshout.vertStride = layout.vertStride;
Expand All @@ -2708,6 +2750,8 @@ void D3D12Replay::InitPostMSBuffers(uint32_t eventId)
ret.meshout.numVerts = totalPrims * layout.indexCountPerPrim;
ret.meshout.instData = meshletOffsets;

ret.meshout.dispatchSize = dispatchSize;

ret.meshout.instStride = 0;

ret.meshout.idxBuf = meshBuffer;
Expand Down Expand Up @@ -4020,6 +4064,8 @@ MeshFormat D3D12Replay::GetPostVSBuffers(uint32_t eventId, uint32_t instID, uint
ret.perPrimitiveStride = s.primStride;
ret.perPrimitiveOffset = s.primOffset;

ret.dispatchSize = s.dispatchSize;

if(stage == MeshDataStage::MeshOut)
{
ret.meshletSizes.resize(s.instData.size());
Expand Down
2 changes: 2 additions & 0 deletions renderdoc/driver/d3d12/d3d12_replay.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ class D3D12Replay : public IReplayDriver
uint64_t idxOffset = 0;
DXGI_FORMAT idxFmt = DXGI_FORMAT_UNKNOWN;

rdcfixedarray<uint32_t, 3> dispatchSize;

bool hasPosOut = false;

float nearPlane = 0.0f;
Expand Down
48 changes: 44 additions & 4 deletions renderdoc/driver/vulkan/vk_postvs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2813,10 +2813,44 @@ void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
{
VulkanCreationInfo &creationInfo = m_pDriver->m_CreationInfo;

const ActionDescription *action = m_pDriver->GetAction(eventId);
ActionDescription action = *m_pDriver->GetAction(eventId);

// for indirect dispatches, fetch up to date dispatch sizes in case they're non-deterministic
if(action.flags & ActionFlags::Indirect)
{
uint32_t chunkIdx = action.events.back().chunkIndex;

const SDFile *file = GetStructuredFile();

// it doesn't matter if this is an indirect sub command or an inlined 1-draw non-indirect count,
// either way the 'offset' is valid - either from the start, or updated for this particular draw
// when we originally patched (and fortunately that part doesn't change).
if(chunkIdx < file->chunks.size())
{
const SDChunk *chunk = file->chunks[chunkIdx];

ResourceId buf = chunk->FindChild("buffer")->AsResourceId();
uint64_t offs = chunk->FindChild("offset")->AsUInt64();

buf = GetResourceManager()->GetLiveID(buf);

bytebuf dispatchArgs;
GetBufferData(buf, offs, sizeof(VkDrawMeshTasksIndirectCommandEXT), dispatchArgs);

if(dispatchArgs.size() >= sizeof(VkDrawMeshTasksIndirectCommandEXT))
{
VkDrawMeshTasksIndirectCommandEXT *meshArgs =
(VkDrawMeshTasksIndirectCommandEXT *)dispatchArgs.data();

action.dispatchDimension[0] = meshArgs->groupCountX;
action.dispatchDimension[1] = meshArgs->groupCountY;
action.dispatchDimension[2] = meshArgs->groupCountZ;
}
}
}

uint32_t totalNumMeshlets =
action->dispatchDimension[0] * action->dispatchDimension[1] * action->dispatchDimension[2];
action.dispatchDimension[0] * action.dispatchDimension[1] * action.dispatchDimension[2];

const VulkanCreationInfo::Pipeline &pipeInfo = creationInfo.m_Pipeline[state.graphics.pipeline];

Expand Down Expand Up @@ -3180,7 +3214,7 @@ void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
modifiedstate.BeginRenderPassAndApplyState(m_pDriver, cmd, VulkanRenderState::BindGraphics,
false);

m_pDriver->ReplayDraw(cmd, *action);
m_pDriver->ReplayDraw(cmd, action);

modifiedstate.EndRenderPass(cmd);

Expand Down Expand Up @@ -3580,7 +3614,7 @@ void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
modifiedstate.BeginRenderPassAndApplyState(m_pDriver, cmd, VulkanRenderState::BindGraphics,
false);

m_pDriver->ReplayDraw(cmd, *action);
m_pDriver->ReplayDraw(cmd, action);

modifiedstate.EndRenderPass(cmd);

Expand Down Expand Up @@ -3963,6 +3997,8 @@ void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
// TODO handle multiple views
ret.taskout.numViews = 1;

ret.taskout.dispatchSize = action.dispatchDimension;

ret.taskout.vertStride = taskPayloadSize + sizeof(Vec4u);
ret.taskout.nearPlane = 0.0f;
ret.taskout.farPlane = 1.0f;
Expand Down Expand Up @@ -3992,6 +4028,8 @@ void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
// TODO handle multiple views
ret.meshout.numViews = 1;

ret.meshout.dispatchSize = action.dispatchDimension;

ret.meshout.vertStride = totalVertStride;
ret.meshout.nearPlane = nearp;
ret.meshout.farPlane = farp;
Expand Down Expand Up @@ -6182,6 +6220,8 @@ MeshFormat VulkanReplay::GetPostVSBuffers(uint32_t eventId, uint32_t instID, uin
ret.perPrimitiveStride = s.primStride;
ret.perPrimitiveOffset = s.primOffset;

ret.dispatchSize = s.dispatchSize;

if(stage == MeshDataStage::MeshOut)
{
ret.meshletSizes.resize(s.instData.size());
Expand Down
2 changes: 2 additions & 0 deletions renderdoc/driver/vulkan/vk_replay.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ struct VulkanPostVSData
VkDeviceMemory idxbufmem;
VkIndexType idxFmt;

rdcfixedarray<uint32_t, 3> dispatchSize;

bool hasPosOut;
bool flipY;

Expand Down
3 changes: 2 additions & 1 deletion renderdoc/replay/renderdoc_serialise.inl
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ void DoSerialise(SerialiserType &ser, MeshFormat &el)
SERIALISE_MEMBER(vertexByteStride);
SERIALISE_MEMBER(vertexByteSize);
SERIALISE_MEMBER(meshletSizes);
SERIALISE_MEMBER(dispatchSize);
SERIALISE_MEMBER(taskSizes);
SERIALISE_MEMBER(meshletIndexOffset);
SERIALISE_MEMBER(perPrimitiveOffset);
Expand All @@ -835,7 +836,7 @@ void DoSerialise(SerialiserType &ser, MeshFormat &el)
SERIALISE_MEMBER(showAlpha);
SERIALISE_MEMBER(status);

SIZE_CHECK(224);
SIZE_CHECK(240);
}

template <typename SerialiserType>
Expand Down

0 comments on commit 4bccf2d

Please sign in to comment.