Skip to content

Commit

Permalink
Draft of cluster RT support
Browse files Browse the repository at this point in the history
Mostly intended for perf analysis in the future

Had to hack the build, possibly because of a driver issue.
  • Loading branch information
zeux committed Feb 3, 2025
1 parent 695bea2 commit b92a865
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 10 deletions.
2 changes: 1 addition & 1 deletion extern/volk
7 changes: 5 additions & 2 deletions src/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

// Validation is enabled by default in Debug
#ifndef NDEBUG
#define KHR_VALIDATION 1
#define KHR_VALIDATION 0
#else
#define KHR_VALIDATION CONFIG_RELVAL
#endif
Expand Down Expand Up @@ -243,7 +243,7 @@ VkPhysicalDevice pickPhysicalDevice(VkPhysicalDevice* physicalDevices, uint32_t
return result;
}

VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint32_t familyIndex, bool meshShadingSupported, bool raytracingSupported)
VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint32_t familyIndex, bool meshShadingSupported, bool raytracingSupported, bool clusterrtSupported)
{
float queuePriorities[] = { 1.0f };

Expand All @@ -266,6 +266,9 @@ VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint
extensions.push_back(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME);
}

if (clusterrtSupported)
extensions.push_back(VK_NV_CLUSTER_ACCELERATION_STRUCTURE_EXTENSION_NAME);

VkPhysicalDeviceFeatures2 features = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2 };
features.features.multiDrawIndirect = true;
features.features.pipelineStatisticsQuery = true;
Expand Down
2 changes: 1 addition & 1 deletion src/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ VkDebugReportCallbackEXT registerDebugCallback(VkInstance instance);
uint32_t getGraphicsFamilyIndex(VkPhysicalDevice physicalDevice);
VkPhysicalDevice pickPhysicalDevice(VkPhysicalDevice* physicalDevices, uint32_t physicalDeviceCount);

VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint32_t familyIndex, bool meshShadingSupported, bool raytracingSupported);
VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint32_t familyIndex, bool meshShadingSupported, bool raytracingSupported, bool clusterrtSupported);
23 changes: 19 additions & 4 deletions src/niagara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,13 @@ int main(int argc, const char** argv)

bool meshShadingSupported = false;
bool raytracingSupported = false;
bool clusterrtSupported = false;

for (auto& ext : extensions)
{
meshShadingSupported = meshShadingSupported || strcmp(ext.extensionName, VK_EXT_MESH_SHADER_EXTENSION_NAME) == 0;
raytracingSupported = raytracingSupported || strcmp(ext.extensionName, VK_KHR_RAY_QUERY_EXTENSION_NAME) == 0;
clusterrtSupported = clusterrtSupported || strcmp(ext.extensionName, VK_NV_CLUSTER_ACCELERATION_STRUCTURE_EXTENSION_NAME) == 0;
}

meshShadingEnabled = meshShadingSupported;
Expand All @@ -387,7 +389,7 @@ int main(int argc, const char** argv)
uint32_t familyIndex = getGraphicsFamilyIndex(physicalDevice);
assert(familyIndex != VK_QUEUE_FAMILY_IGNORED);

VkDevice device = createDevice(instance, physicalDevice, familyIndex, meshShadingSupported, raytracingSupported);
VkDevice device = createDevice(instance, physicalDevice, familyIndex, meshShadingSupported, raytracingSupported, clusterrtSupported);
assert(device);

volkLoadDevice(device);
Expand Down Expand Up @@ -816,9 +818,22 @@ int main(int argc, const char** argv)
Buffer tlasInstanceBuffer = {};
if (raytracingSupported)
{
std::vector<VkDeviceSize> compactedSizes;
buildBLAS(device, geometry.meshes, vb, ib, blas, compactedSizes, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);
compactBLAS(device, blas, compactedSizes, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);
if (clusterrtSupported)
{
Buffer vxb = {};
createBuffer(vxb, device, memoryProperties, geometry.meshletvtx0.size() * sizeof(uint16_t), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
memcpy(vxb.data, geometry.meshletvtx0.data(), geometry.meshletvtx0.size() * sizeof(uint16_t));

buildCLAS(device, geometry.meshes, geometry.meshlets, vxb, mdb, blas, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);

destroyBuffer(vxb, device);
}
else
{
std::vector<VkDeviceSize> compactedSizes;
buildBLAS(device, geometry.meshes, vb, ib, blas, compactedSizes, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);
compactBLAS(device, blas, compactedSizes, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);
}

blasAddresses.resize(blas.size());

Expand Down
21 changes: 19 additions & 2 deletions src/scene.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <memory>
#include <cstring>

static size_t appendMeshlets(Geometry& result, const std::vector<vec3>& vertices, const std::vector<uint32_t>& indices, uint32_t baseVertex, bool fast = false)
static size_t appendMeshlets(Geometry& result, const std::vector<vec3>& vertices, std::vector<uint32_t>& indices, uint32_t baseVertex, bool lod0, bool fast = false)
{
const size_t max_vertices = MESH_MAXVTX;
const size_t max_triangles = MESH_MAXTRI;
Expand Down Expand Up @@ -58,6 +58,23 @@ static size_t appendMeshlets(Geometry& result, const std::vector<vec3>& vertices
for (unsigned int i = 0; i < indexGroupCount; ++i)
result.meshletdata.push_back(indexGroups[i]);

if (lod0)
{
for (unsigned int i = 0; i < meshlet.vertex_count; ++i)
{
unsigned int vtx = meshlet_vertices[meshlet.vertex_offset + i];

unsigned short hx = meshopt_quantizeHalf(vertices[vtx].x);
unsigned short hy = meshopt_quantizeHalf(vertices[vtx].y);
unsigned short hz = meshopt_quantizeHalf(vertices[vtx].z);

result.meshletvtx0.push_back(hx);
result.meshletvtx0.push_back(hy);
result.meshletvtx0.push_back(hz);
result.meshletvtx0.push_back(0);
}
}

meshopt_Bounds bounds = meshopt_computeMeshletBounds(&meshlet_vertices[meshlet.vertex_offset], &meshlet_triangles[meshlet.triangle_offset], meshlet.triangle_count, &vertices[0].x, vertices.size(), sizeof(vec3));

Meshlet m = {};
Expand Down Expand Up @@ -203,7 +220,7 @@ static void appendMesh(Geometry& result, std::vector<Vertex>& vertices, std::vec
result.indices.insert(result.indices.end(), lodIndices.begin(), lodIndices.end());

lod.meshletOffset = uint32_t(result.meshlets.size());
lod.meshletCount = buildMeshlets ? uint32_t(appendMeshlets(result, positions, lodIndices, mesh.vertexOffset, fast)) : 0;
lod.meshletCount = buildMeshlets ? uint32_t(appendMeshlets(result, positions, lodIndices, mesh.vertexOffset, &lod == mesh.lods, fast)) : 0;

lod.error = lodError * lodScale;

Expand Down
1 change: 1 addition & 0 deletions src/scene.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct Geometry
std::vector<uint32_t> indices;
std::vector<Meshlet> meshlets;
std::vector<uint32_t> meshletdata;
std::vector<uint16_t> meshletvtx0; // 4 position components per vertex referenced by meshlets in lod 0, packed tightly
std::vector<Mesh> meshes;
};

Expand Down
220 changes: 220 additions & 0 deletions src/scenert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "scene.h"
#include "resources.h"

#include "config.h"

#include <string.h>

void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& vb, const Buffer& ib, std::vector<VkAccelerationStructureKHR>& blas, std::vector<VkDeviceSize>& compactedSizes, Buffer& blasBuffer, VkCommandPool commandPool, VkCommandBuffer commandBuffer, VkQueue queue, const VkPhysicalDeviceMemoryProperties& memoryProperties)
Expand Down Expand Up @@ -224,6 +226,224 @@ void compactBLAS(VkDevice device, std::vector<VkAccelerationStructureKHR>& blas,
blasBuffer = compactedBuffer;
}

// Builds ray tracing geometry from lod 0 meshlets using the cluster acceleration structure
// commands (vkCmdBuildClusterAccelerationStructureIndirectNV), in two GPU passes:
//   1. one triangle cluster (CLAS) per lod 0 meshlet of every mesh;
//   2. one bottom-level acceleration structure per mesh that references that mesh's clusters.
// Each pass is submitted on `queue` and followed by vkDeviceWaitIdle, so the function is fully synchronous.
//
// meshes / meshlets: only lods[0] of each mesh is consumed (meshletOffset/meshletCount index into `meshlets`).
// vxb: cluster vertex positions, read as 4 x 16-bit components per vertex (VK_FORMAT_R16G16B16A16_SFLOAT),
//      packed back to back in the same mesh/meshlet order this function iterates in.
// mdb: meshlet data; 8-bit triangle indices are read starting right after each meshlet's vertex references.
// blas: resized to meshes.size() and filled with one acceleration structure handle per mesh.
// blasBuffer: created here and used as backing storage for all handles in `blas`.
//
// NOTE(review): getBufferAddress is called on vxb and mdb; confirm that callers create both buffers
// in a way that makes querying their device address valid.
// NOTE: the buffer that stores the clusters themselves is currently never destroyed (see TODO at the end).
void buildCLAS(VkDevice device, const std::vector<Mesh>& meshes, const std::vector<Meshlet>& meshlets, const Buffer& vxb, const Buffer& mdb, std::vector<VkAccelerationStructureKHR>& blas, Buffer& blasBuffer, VkCommandPool commandPool, VkCommandBuffer commandBuffer, VkQueue queue, const VkPhysicalDeviceMemoryProperties& memoryProperties)
{
	// One record of the range buffer: both build passes write the address of each built structure
	// at offset 0 and its size at offset 8, using a 16 byte stride.
	struct RangeEntry
	{
		uint64_t address;
		uint32_t size;
		uint32_t padding;
	};
	static_assert(sizeof(RangeEntry) == 16, "range entries must be 16 bytes to match the address/size strides");

	const VkDeviceSize rangeStride = sizeof(RangeEntry);
	const VkDeviceSize rangeSizeOffset = sizeof(uint64_t);

	VkClusterAccelerationStructureTriangleClusterInputNV clusterSizes = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_TRIANGLE_CLUSTER_INPUT_NV };
	clusterSizes.vertexFormat = VK_FORMAT_R16G16B16A16_SFLOAT;
	clusterSizes.maxGeometryIndexValue = 0;
	clusterSizes.maxClusterUniqueGeometryCount = 1;
	clusterSizes.maxClusterTriangleCount = MESH_MAXTRI;
	clusterSizes.maxClusterVertexCount = MESH_MAXVTX;
	clusterSizes.minPositionTruncateBitCount = 0;

	VkClusterAccelerationStructureInputInfoNV clusterInfo = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_INPUT_INFO_NV };
	clusterInfo.maxAccelerationStructureCount = 0;
	clusterInfo.flags = VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_TRACE_BIT_KHR | VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_COMPACTION_BIT_KHR;
	clusterInfo.opType = VK_CLUSTER_ACCELERATION_STRUCTURE_OP_TYPE_BUILD_TRIANGLE_CLUSTER_NV;
	clusterInfo.opMode = VK_CLUSTER_ACCELERATION_STRUCTURE_OP_MODE_IMPLICIT_DESTINATIONS_NV;
	clusterInfo.opInput.pTriangleClusters = &clusterSizes;

	size_t maxClustersPerMesh = 0;

	// accumulate upper bounds over lod 0 of every mesh; these drive the build size queries below
	for (const Mesh& mesh : meshes)
	{
		clusterSizes.maxTotalTriangleCount += mesh.lods[0].indexCount / 3;
		clusterSizes.maxTotalVertexCount += mesh.vertexCount;
		clusterInfo.maxAccelerationStructureCount += mesh.lods[0].meshletCount;
		maxClustersPerMesh = std::max(maxClustersPerMesh, size_t(mesh.lods[0].meshletCount));
	}

	const uint32_t clusterCount = clusterInfo.maxAccelerationStructureCount;
	const uint32_t meshCount = uint32_t(meshes.size());

	VkClusterAccelerationStructureClustersBottomLevelInputNV accelSizes = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_CLUSTERS_BOTTOM_LEVEL_INPUT_NV };
	accelSizes.maxTotalClusterCount = clusterCount;
	accelSizes.maxClusterCountPerAccelerationStructure = uint32_t(maxClustersPerMesh);

	VkClusterAccelerationStructureInputInfoNV accelInfo = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_INPUT_INFO_NV };
	accelInfo.maxAccelerationStructureCount = meshCount;
	accelInfo.flags = VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_TRACE_BIT_KHR | VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_COMPACTION_BIT_KHR;
	accelInfo.opType = VK_CLUSTER_ACCELERATION_STRUCTURE_OP_TYPE_BUILD_CLUSTERS_BOTTOM_LEVEL_NV;
	accelInfo.opMode = VK_CLUSTER_ACCELERATION_STRUCTURE_OP_MODE_IMPLICIT_DESTINATIONS_NV;
	accelInfo.opInput.pClustersBottomLevel = &accelSizes;

	VkAccelerationStructureBuildSizesInfoKHR csizeInfo = { VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR };
	vkGetClusterAccelerationStructureBuildSizesNV(device, &clusterInfo, &csizeInfo);

	VkAccelerationStructureBuildSizesInfoKHR bsizeInfo = { VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR };
	vkGetClusterAccelerationStructureBuildSizesNV(device, &accelInfo, &bsizeInfo);

	printf("CLAS accelerationStructureSize: %.2f MB, scratchSize: %.2f MB\n", double(csizeInfo.accelerationStructureSize) / 1e6, double(csizeInfo.buildScratchSize) / 1e6);
	printf("CBLAS accelerationStructureSize: %.2f MB, scratchSize: %.2f MB\n", double(bsizeInfo.accelerationStructureSize) / 1e6, double(bsizeInfo.buildScratchSize) / 1e6);

	// storage for the clusters (pass 1 output)
	Buffer clasBuffer = {};
	createBuffer(clasBuffer, device, memoryProperties, csizeInfo.accelerationStructureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

	// storage for the per-mesh acceleration structures (pass 2 output)
	createBuffer(blasBuffer, device, memoryProperties, bsizeInfo.accelerationStructureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

	// the two passes never overlap (we wait for idle in between), so scratch and infos are shared and sized for the larger pass
	Buffer scratchBuffer = {};
	createBuffer(scratchBuffer, device, memoryProperties, std::max(bsizeInfo.buildScratchSize, csizeInfo.buildScratchSize), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

	const size_t clusterInfosSize = clusterCount * sizeof(VkClusterAccelerationStructureBuildTriangleClusterInfoNV);
	const size_t accelInfosSize = meshCount * sizeof(VkClusterAccelerationStructureBuildClustersBottomLevelInfoNV);

	Buffer infosBuffer = {};
	createBuffer(infosBuffer, device, memoryProperties, std::max(clusterInfosSize, accelInfosSize), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);

	VkDeviceAddress mdbAddress = getBufferAddress(mdb, device);
	VkDeviceAddress vxbAddress = getBufferAddress(vxb, device);

	// pass 1 inputs: one build record per lod 0 meshlet, written straight into the mapped infos buffer
	VkClusterAccelerationStructureBuildTriangleClusterInfoNV* clusterData = static_cast<VkClusterAccelerationStructureBuildTriangleClusterInfoNV*>(infosBuffer.data);
	size_t vxbOffset = 0;

	for (const Mesh& mesh : meshes)
	{
		for (size_t mi = 0; mi < mesh.lods[0].meshletCount; ++mi)
		{
			const Meshlet& ml = meshlets[mesh.lods[0].meshletOffset + mi];

			VkClusterAccelerationStructureBuildTriangleClusterInfoNV cluster = {};
			cluster.clusterID = uint32_t(mi);
			cluster.triangleCount = ml.triangleCount;
			cluster.vertexCount = ml.vertexCount;
			cluster.positionTruncateBitCount = 0;
			cluster.indexType = VK_CLUSTER_ACCELERATION_STRUCTURE_INDEX_FORMAT_8BIT_NV;
			cluster.vertexBufferStride = sizeof(uint16_t) * 4;
			// triangle indices follow the vertex references, which take one 32-bit word per vertex (or per two vertices with shortRefs)
			cluster.indexBuffer = mdbAddress + (ml.dataOffset + (ml.shortRefs ? (ml.vertexCount + 1) / 2 : ml.vertexCount)) * sizeof(uint32_t);
			cluster.vertexBuffer = vxbAddress + vxbOffset;

			memcpy(clusterData, &cluster, sizeof(VkClusterAccelerationStructureBuildTriangleClusterInfoNV));
			clusterData++;
			vxbOffset += ml.vertexCount * (sizeof(uint16_t) * 4);
		}
	}

	// range buffer: clusterCount entries written by pass 1, followed by meshCount entries written by pass 2
	Buffer rangeBuffer = {};
	// todo host vis -> device local?
	createBuffer(rangeBuffer, device, memoryProperties, (VkDeviceSize(clusterCount) + meshCount) * rangeStride, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);

	const VkDeviceAddress rangeAddress = getBufferAddress(rangeBuffer, device);
	const VkDeviceAddress scratchAddress = getBufferAddress(scratchBuffer, device);
	const VkDeviceAddress infosAddress = getBufferAddress(infosBuffer, device);

	const RangeEntry* clusterRanges = static_cast<const RangeEntry*>(rangeBuffer.data);
	const RangeEntry* accelRanges = clusterRanges + clusterCount;

	VkClusterAccelerationStructureCommandsInfoNV clusterBuild = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_COMMANDS_INFO_NV };
	clusterBuild.input = clusterInfo;
	clusterBuild.dstImplicitData = getBufferAddress(clasBuffer, device);
	clusterBuild.scratchData = scratchAddress;
	clusterBuild.dstAddressesArray.deviceAddress = rangeAddress;
	clusterBuild.dstAddressesArray.size = clusterCount * rangeStride;
	clusterBuild.dstAddressesArray.stride = rangeStride;
	clusterBuild.dstSizesArray.deviceAddress = rangeAddress + rangeSizeOffset;
	clusterBuild.dstSizesArray.size = clusterCount * rangeStride - rangeSizeOffset;
	clusterBuild.dstSizesArray.stride = rangeStride;
	clusterBuild.srcInfosArray.deviceAddress = infosAddress;
	clusterBuild.srcInfosArray.size = clusterInfosSize;

	VK_CHECK(vkResetCommandPool(device, commandPool, 0));

	VkCommandBufferBeginInfo beginInfo = { VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO };
	beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;

	VK_CHECK(vkBeginCommandBuffer(commandBuffer, &beginInfo));

	vkCmdBuildClusterAccelerationStructureIndirectNV(commandBuffer, &clusterBuild);

	VK_CHECK(vkEndCommandBuffer(commandBuffer));

	VkSubmitInfo submitInfo = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
	submitInfo.commandBufferCount = 1;
	submitInfo.pCommandBuffers = &commandBuffer;

	VK_CHECK(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE));
	VK_CHECK(vkDeviceWaitIdle(device));

	// pass 1 is complete: the GPU has written the address and size of every cluster into the range buffer
	size_t totalClusterSize = 0;
	for (size_t i = 0; i < clusterCount; ++i)
	{
		// printf("cluster %d: %016llx %08x\n", int(i), static_cast<unsigned long long>(clusterRanges[i].address), unsigned(clusterRanges[i].size));

		totalClusterSize += clusterRanges[i].size;
	}

	printf("CLAS actual total accelerationStructureSize: %.2f MB\n", double(totalClusterSize) / 1e6);

	// pass 2 inputs: one build record per mesh; the infos buffer is reused now that pass 1 has finished reading it
	VkClusterAccelerationStructureBuildClustersBottomLevelInfoNV* accelData = static_cast<VkClusterAccelerationStructureBuildClustersBottomLevelInfoNV*>(infosBuffer.data);
	size_t accelOffset = 0;

	printf("max cluster count %d\n", int(maxClustersPerMesh));

	for (const Mesh& mesh : meshes)
	{
		VkClusterAccelerationStructureBuildClustersBottomLevelInfoNV accel = {};
		accel.clusterReferencesCount = uint32_t(mesh.lods[0].meshletCount);
		accel.clusterReferencesStride = uint32_t(rangeStride);
		// cluster addresses written by pass 1 double as the reference list; this mesh's clusters start at entry accelOffset
		accel.clusterReferences = rangeAddress + accelOffset * rangeStride;

		// NOTE: HUGE HACK!!!
		// ON NV 3050 we have a single mesh with over 2000 clusters in Bistro scene
		// if we keep the entire cluster set in the build, the GPU will hang during build
		// we will print here the first cluster we skip from the mesh; not sure what is going on really.
		if (accel.clusterReferencesCount > 2000)
		{
			// IMPORTANT:
			// This is definitely not a "broken" cluster: we can shift the range of clusters we build here by a couple clusters forward, and it still works
			// accel.clusterReferences += 32;
			// ... but if we get just one extra cluster (1537 below), GPU hangs
			// It's suspicious that 1536 == 1024 + 512 but who knows what's going on.
			accel.clusterReferencesCount = 1536;

			const RangeEntry& skipped = clusterRanges[accelOffset + accel.clusterReferencesCount];
			const Meshlet& skippedMeshlet = meshlets[mesh.lods[0].meshletOffset + accel.clusterReferencesCount];

			printf("HACK: in mesh %d we are going to skip clusters starting from %d\n", int(&mesh - meshes.data()), int(accel.clusterReferencesCount));
			// 64-bit address is printed via unsigned long long so the format is correct regardless of the width of long
			printf("SKIP: cluster blas %016llx size %04x tris %d vertices %d\n",
			    static_cast<unsigned long long>(skipped.address),
			    unsigned(skipped.size),
			    int(skippedMeshlet.triangleCount),
			    int(skippedMeshlet.vertexCount));
		}

		// printf("mesh %d: offset %d count %d\n", int(&mesh - meshes.data()), int(accelOffset), mesh.lods[0].meshletCount);

		memcpy(accelData, &accel, sizeof(VkClusterAccelerationStructureBuildClustersBottomLevelInfoNV));
		accelData++;
		// advance by the full meshlet count (not the possibly truncated reference count) to stay aligned with pass 1 output
		accelOffset += mesh.lods[0].meshletCount;
	}

	VkClusterAccelerationStructureCommandsInfoNV accelBuild = { VK_STRUCTURE_TYPE_CLUSTER_ACCELERATION_STRUCTURE_COMMANDS_INFO_NV };
	accelBuild.input = accelInfo;
	accelBuild.dstImplicitData = getBufferAddress(blasBuffer, device);
	accelBuild.scratchData = scratchAddress;
	accelBuild.dstAddressesArray.deviceAddress = rangeAddress + clusterCount * rangeStride;
	accelBuild.dstAddressesArray.size = meshCount * rangeStride;
	accelBuild.dstAddressesArray.stride = rangeStride;
	accelBuild.dstSizesArray.deviceAddress = rangeAddress + clusterCount * rangeStride + rangeSizeOffset;
	accelBuild.dstSizesArray.size = meshCount * rangeStride - rangeSizeOffset;
	accelBuild.dstSizesArray.stride = rangeStride;
	accelBuild.srcInfosArray.deviceAddress = infosAddress;
	accelBuild.srcInfosArray.size = accelInfosSize;

	VK_CHECK(vkResetCommandPool(device, commandPool, 0));
	VK_CHECK(vkBeginCommandBuffer(commandBuffer, &beginInfo));

	vkCmdBuildClusterAccelerationStructureIndirectNV(commandBuffer, &accelBuild);

	VK_CHECK(vkEndCommandBuffer(commandBuffer));
	VK_CHECK(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE));
	VK_CHECK(vkDeviceWaitIdle(device));

	VkDeviceAddress blasAddress = getBufferAddress(blasBuffer, device);

	blas.resize(meshes.size());

	// wrap each region the GPU reported for pass 2 into an acceleration structure handle backed by blasBuffer
	for (size_t i = 0; i < meshCount; ++i)
	{
		VkAccelerationStructureCreateInfoKHR accelerationInfo = { VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR };
		accelerationInfo.buffer = blasBuffer.buffer;
		accelerationInfo.offset = accelRanges[i].address - blasAddress;
		accelerationInfo.size = accelRanges[i].size;
		accelerationInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR;

		VK_CHECK(vkCreateAccelerationStructureKHR(device, &accelerationInfo, nullptr, &blas[i]));
	}

	destroyBuffer(scratchBuffer, device);
	destroyBuffer(infosBuffer, device);
	destroyBuffer(rangeBuffer, device);

	// TODO: destroyBuffer(clasBuffer, device);
}

void fillInstanceRT(VkAccelerationStructureInstanceKHR& instance, const MeshDraw& draw, uint32_t instanceIndex, VkDeviceAddress blas)
{
mat3 xform = transpose(glm::mat3_cast(draw.orientation)) * draw.scale;
Expand Down
Loading

0 comments on commit b92a865

Please sign in to comment.