From 0eedfa6e705af0813175d5126de5c3e76c592d9a Mon Sep 17 00:00:00 2001 From: Mateusz Szczygielski <112629916+msz-rai@users.noreply.github.com> Date: Sun, 19 Nov 2023 14:50:14 +0100 Subject: [PATCH] Fix nodes validation on rays change (#220) * Add test * Clear getGraphRunCtx if rays modified * Optimize * Fix optimization * Remove redundant fields existence checks * Review changes --- .../pcl/src/graph/DownSamplePointsNode.cpp | 5 +---- extensions/pcl/src/graph/NodesPcl.hpp | 1 - .../pcl/src/graph/VisualizePointsNode.cpp | 9 -------- src/api/apiCommon.hpp | 21 +++++++++++------- src/graph/GraphRunCtx.hpp | 10 +++++++++ .../graph/nodes/SetRingIdsRaysNodeTest.cpp | 22 +++++++++++++++---- 6 files changed, 42 insertions(+), 26 deletions(-) diff --git a/extensions/pcl/src/graph/DownSamplePointsNode.cpp b/extensions/pcl/src/graph/DownSamplePointsNode.cpp index b940b36d..54b67c55 100644 --- a/extensions/pcl/src/graph/DownSamplePointsNode.cpp +++ b/extensions/pcl/src/graph/DownSamplePointsNode.cpp @@ -24,10 +24,7 @@ using PCLPoint = pcl::PointXYZL; void DownSamplePointsNode::validateImpl() { IPointsNodeSingleInput::validateImpl(); - if (!input->hasField(XYZ_VEC3_F32)) { - auto msg = fmt::format("{} requires XYZ to be present", getName()); - throw InvalidPipeline(msg); - } + // Needed to clear cache because fields in the pipeline may have changed // In fact, the cache manager is no longer useful here // To be kept/removed in some future refactor (when resolving comment in the `enqueueExecImpl`) diff --git a/extensions/pcl/src/graph/NodesPcl.hpp b/extensions/pcl/src/graph/NodesPcl.hpp index e8b2d395..f9db4cf0 100644 --- a/extensions/pcl/src/graph/NodesPcl.hpp +++ b/extensions/pcl/src/graph/NodesPcl.hpp @@ -64,7 +64,6 @@ struct VisualizePointsNode : IPointsNodeSingleInput void setParameters(const char* windowName, int windowWidth, int windowHeight, bool fullscreen); // Node - void validateImpl() override; void enqueueExecImpl() override; // Node requirements diff --git a/extensions/pcl/src/graph/VisualizePointsNode.cpp b/extensions/pcl/src/graph/VisualizePointsNode.cpp index 16036c92..28001856 100644 --- a/extensions/pcl/src/graph/VisualizePointsNode.cpp +++ b/extensions/pcl/src/graph/VisualizePointsNode.cpp @@ -33,15 +33,6 @@ void VisualizePointsNode::setParameters(const char* windowName, int windowWidth, visualizeThread->visualizeNodes.push_back(std::dynamic_pointer_cast(shared_from_this())); } -void VisualizePointsNode::validateImpl() -{ - IPointsNodeSingleInput::validateImpl(); - if (!input->hasField(XYZ_VEC3_F32)) { - auto msg = fmt::format("{} requires XYZ to be present", getName()); - throw InvalidPipeline(msg); - } -} - // All calls to the viewers must be executed from the same thread void VisualizePointsNode::VisualizeThread::runVisualize() try { diff --git a/src/api/apiCommon.hpp b/src/api/apiCommon.hpp index 5cd824e8..c915d413 100644 --- a/src/api/apiCommon.hpp +++ b/src/api/apiCommon.hpp @@ -113,20 +113,25 @@ void createOrUpdateNode(rgl_node_t* nodeRawPtr, Args&&... args) } else { node = Node::validatePtr(*nodeRawPtr); } - // TODO: The magic below detects calls changing rgl_field_t* (e.g. FormatPointsNode) - // TODO: Such changes may require recomputing required fields in RaytraceNode. - // TODO: However, taking care of this manually is very bug prone. - // TODO: There are other ways to automate this, however, for now this should be enough. - bool fieldsModified = ((std::is_same_v> || ...)); - if (fieldsModified && node->hasGraphRunCtx()) { - node->getGraphRunCtx()->detachAndDestroy(); - } // As of now, there's no guarantee that changing node parameter won't influence other nodes // Therefore, before changing them, we need to ensure all nodes are idle (not running in GraphRunCtx). if (node->hasGraphRunCtx()) { node->getGraphRunCtx()->synchronize(); } + + // TODO: The magic below detects calls changing rgl_field_t* (e.g. FormatPointsNode) or changing rays definition + // TODO: Such changes may require recomputing required fields in RaytraceNode + // TODO: or performing validation in nodes dependent on ray count (e.g. SetRingIdsRaysNode) + // TODO: However, taking care of this manually is very bug prone. + // TODO: There are other ways to automate this, however, for now this should be enough. + bool fieldsModified = (std::is_same_v> || ...); + bool raysModified = std::is_same_v; + bool graphValidationNeeded = fieldsModified || raysModified; + if (graphValidationNeeded && node->hasGraphRunCtx()) { + node->getGraphRunCtx()->markNodesDirty(); + } + node->setParameters(std::forward(args)...); node->dirty = true; *nodeRawPtr = node.get(); diff --git a/src/graph/GraphRunCtx.hpp b/src/graph/GraphRunCtx.hpp index eac3d135..b8039771 100644 --- a/src/graph/GraphRunCtx.hpp +++ b/src/graph/GraphRunCtx.hpp @@ -63,6 +63,16 @@ struct GraphRunCtx */ void synchronizeNodeCPU(Node::ConstPtr nodeToSynchronize); + /** + * Marks all nodes dirty. + */ + void markNodesDirty() + { + for (auto&& node : nodes) { + node->dirty = true; + } + } + bool isThisThreadGraphThread() const { return maybeThread.has_value() && maybeThread->get_id() == std::this_thread::get_id(); diff --git a/test/src/graph/nodes/SetRingIdsRaysNodeTest.cpp b/test/src/graph/nodes/SetRingIdsRaysNodeTest.cpp index 6b7e8c7f..d6577630 100644 --- a/test/src/graph/nodes/SetRingIdsRaysNodeTest.cpp +++ b/test/src/graph/nodes/SetRingIdsRaysNodeTest.cpp @@ -12,14 +12,14 @@ class SetRingIdsNodeTest : public RGLTestWithParam void initializeRingNodeAndIds(int idsCount) { setRingIdsNode = nullptr; - std::vector ids(idsCount); - std::iota(ids.begin(), ids.end(), 0); - ringIds = ids; + ringIds.resize(idsCount); + std::iota(ringIds.begin(), ringIds.end(), 0); } void initializeRaysAndRaysNode(int rayCount) { rayNode = nullptr; + rays.clear(); rays.reserve(rayCount); for (int i = 0; i < rayCount; i++) { rays.emplace_back( @@ -64,8 +64,9 @@ TEST_P(SetRingIdsNodeTest, invalid_pipeline_less_rays_than_ring_ids) int32_t idsCount = GetParam(); if (idsCount / 2 == 0) { return; - }; + } + //// Incorrect number of ring ids passed to the rgl_node_rays_set_ring_ids //// initializeRingNodeAndIds(idsCount); initializeRaysAndRaysNode(idsCount / 2); ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size())); @@ -73,6 +74,19 @@ TEST_P(SetRingIdsNodeTest, invalid_pipeline_less_rays_than_ring_ids) ASSERT_RGL_SUCCESS(rgl_graph_node_add_child(rayNode, setRingIdsNode)); EXPECT_RGL_INVALID_PIPELINE(rgl_graph_run(setRingIdsNode), "ring ids doesn't match number of rays"); + + //// Changed number of rays between graph runs //// + // Initialize and run valid pipeline + initializeRingNodeAndIds(idsCount); + initializeRaysAndRaysNode(idsCount); + ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size())); + ASSERT_RGL_SUCCESS(rgl_node_rays_set_ring_ids(&setRingIdsNode, ringIds.data(), ringIds.size())); + ASSERT_RGL_SUCCESS(rgl_graph_node_add_child(rayNode, setRingIdsNode)); + EXPECT_RGL_SUCCESS(rgl_graph_run(setRingIdsNode)); + + // Make pipeline invalid + ASSERT_RGL_SUCCESS(rgl_node_rays_from_mat3x4f(&rayNode, rays.data(), rays.size() / 2)); + EXPECT_RGL_INVALID_PIPELINE(rgl_graph_run(setRingIdsNode), "ring ids doesn't match number of rays"); } TEST_P(SetRingIdsNodeTest, valid_pipeline_equal_number_of_rays_and_ring_ids)