feat: Add TensorRT support for GNNs #4016
base: main
Conversation
Actionable comments posted: 7
🧹 Nitpick comments (4)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
98-100: Use ACTS logging instead of `std::cout`, prefer you should.

For consistency within the codebase, replace `std::cout` with ACTS logging macros. Apply this diff to use the logging framework:

```diff
 ~TimePrinter() {
-  std::cout << name << ": " << milliseconds(t0, t1) << std::endl;
+  ACTS_INFO(name << ": " << milliseconds(t0, t1) << " ms");
 }
```

Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (2)
38-41: Destructor to be marked `override`, consider you should.

Since the base class has a virtual destructor, marking the destructor in the derived class with `override` good practice it is. Apply this diff for clarity:

```diff
-  ~TensorRTEdgeClassifier();
+  ~TensorRTEdgeClassifier() override;
```
49-58: Member variables' initialization order, ensure you must.

Initialize member variables in the order they are declared to avoid warnings. Ensure that `m_cfg` is initialized before `m_trtLogger`, as declared.

Examples/Python/src/ExaTrkXTrackFinding.cpp (1)
110-128: Logger name, more specific make you should.

For clarity and consistency, use a distinct logger name for `TensorRTEdgeClassifier`. Apply this diff to specify the logger name:

```diff
 return std::make_shared<Alg>(
-    c, getDefaultLogger("EdgeClassifier", lvl));
+    c, getDefaultLogger("TensorRTEdgeClassifier", lvl));
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- .gitlab-ci.yml (1 hunks)
- Examples/Python/src/ExaTrkXTrackFinding.cpp (2 hunks)
- Plugins/ExaTrkX/CMakeLists.txt (1 hunks)
- Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (1 hunks)
- Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: CI Bridge / build_gnn_tensorrt
- GitHub Check: linux_physmon
- GitHub Check: linux_examples_test
- GitHub Check: missing_includes
- GitHub Check: linux_ubuntu_extra (ubuntu2204_clang, 20)
- GitHub Check: build_debug
🔇 Additional comments (3)
Examples/Python/src/ExaTrkXTrackFinding.cpp (1)
126-126: Missing configuration member `useEdgeFeatures`, verify you should.

Inconsistent the configuration is with other classifiers. Include `useEdgeFeatures` if required. Ensure that all necessary configuration options are included.
Plugins/ExaTrkX/CMakeLists.txt (2)
Line range hint 1-38: Well-structured, this CMake configuration is!

Follow consistent patterns for different backends, it does. Proper organization and clarity, I sense.
26-38: Version constraints for TensorRT, specify we must!

Hmmmm, missing version constraints for TensorRT package, I see. Dangerous this can be, yes. Compatibility issues, it may cause. Apply this change, you should:

```diff
-find_package(TensorRT REQUIRED)
+find_package(TensorRT 8.6 REQUIRED)
```
Actionable comments posted: 1
🧹 Nitpick comments (4)
cmake/FindTensorRT.cmake (4)
1-28: Hmmmm, missing minimum version requirements, I sense.

Document the minimum required TensorRT version for this module, you should. Help users avoid compatibility issues, it will.
30-34: Validate components list, we must.

Check for invalid component names in `TensorRT_FIND_COMPONENTS`, wise it would be. Prevent configuration errors early, this will.

```diff
 if(NOT TensorRT_FIND_COMPONENTS)
     set(TensorRT_FIND_COMPONENTS nvinfer nvinfer_plugin nvonnxparser nvparsers)
 endif()
+set(_valid_components nvinfer nvinfer_plugin nvonnxparser nvparsers)
+foreach(component IN LISTS TensorRT_FIND_COMPONENTS)
+    if(NOT component IN_LIST _valid_components)
+        message(FATAL_ERROR "Invalid TensorRT component specified: ${component}")
+    endif()
+endforeach()
 set(TensorRT_LIBRARIES)
```
44-53: More helpful error message, provide we should.

Include the searched paths in the error message, helpful it would be. Guide users to correct configuration, this will.

```diff
 if(TensorRT_FIND_REQUIRED)
     message(
         FATAL_ERROR
-        "Fail to find TensorRT, please set TensorRT_ROOT. Include path not found."
+        "Failed to find TensorRT header NvInfer.h. Searched in:\n"
+        "  - ${TensorRT_ROOT}/include\n"
+        "  - $ENV{TensorRT_ROOT}/include\n"
+        "Please set TensorRT_ROOT to the installation directory."
     )
 endif()
```
171-174: Process components in parallel, consider we should.

For large component lists, parallel processing could speed up configuration. Optional it is, but beneficial for large projects.

```diff
+if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.19")
+    cmake_policy(SET CMP0114 NEW)
+    foreach(component IN LISTS TensorRT_FIND_COMPONENTS)
+        cmake_language(DEFER DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+            CALL _find_trt_component ${component})
+    endforeach()
+else()
 foreach(component IN LISTS TensorRT_FIND_COMPONENTS)
     _find_trt_component(${component})
 endforeach()
+endif()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- Examples/Python/src/ExaTrkXTrackFinding.cpp (2 hunks)
- cmake/FindTensorRT.cmake (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- Examples/Python/src/ExaTrkXTrackFinding.cpp
⏰ Context from checks skipped due to timeout of 90000ms (8)
- GitHub Check: linux_ubuntu_extra (ubuntu2204_clang, 20)
- GitHub Check: linux_ubuntu_extra (ubuntu2204, 20)
- GitHub Check: macos
- GitHub Check: linux_ubuntu
- GitHub Check: unused_files
- GitHub Check: missing_includes
- GitHub Check: build_debug
- GitHub Check: CI Bridge / build_gnn_tensorrt
🔇 Additional comments (2)
cmake/FindTensorRT.cmake (2)
176-182: Well implemented, the package handling is.

Following CMake best practices, you are. Approve this section, I do.
56-114: Handle malformed version strings, we must.

Verify that version strings are properly extracted, essential it is. Add error handling for malformed version strings, we should.
Actionable comments posted: 1
♻️ Duplicate comments (1)
.gitlab-ci.yml (1)
189-224: ⚠️ Potential issue: Align with existing ExaTrkX jobs and add missing configurations, you must!

Similar to the past review, improvements needed there are:
- CUDA architectures with other ExaTrkX jobs, align we must
- Artifacts for downstream jobs, configure we should
- Testing stage for TensorRT functionality, define we must
- Torch disabled while TensorRT enabled, verify this approach we should

Apply these changes, you should:

```diff
 build_gnn_tensorrt:
   stage: build
   image: ghcr.io/acts-project/ubuntu2404_tensorrt:74
   variables:
     DEPENDENCY_URL: https://acts.web.cern.ch/ACTS/ci/ubuntu-24.04/deps.$DEPENDENCY_TAG.tar.zst
+    TORCH_CUDA_ARCH_LIST: "8.0 8.6 8.9 9.0"
   cache:
     key: ccache-${CI_JOB_NAME}-${CI_COMMIT_REF_SLUG}-${CCACHE_KEY_SUFFIX}
     fallback_keys:
       - ccache-${CI_JOB_NAME}-${CI_DEFAULT_BRANCH}-${CCACHE_KEY_SUFFIX}
     when: always
     paths:
       - ${CCACHE_DIR}
+  artifacts:
+    paths:
+      - build/
+    exclude:
+      - build/**/*.o
+    expire_in: 6 hours
   tags:
     - docker-gpu-nvidia
   script:
     - git clone $CLONE_URL src
     - cd src
     - git checkout $HEAD_SHA
     - source CI/dependencies.sh
     - cd ..
     - mkdir build
     - >
       cmake -B build -S src
       -DACTS_BUILD_PLUGIN_EXATRKX=ON
-      -DACTS_EXATRKX_ENABLE_TORCH=OFF
       -DACTS_EXATRKX_ENABLE_CUDA=ON
       -DACTS_EXATRKX_ENABLE_TENSORRT=ON
       -DPython_EXECUTABLE=$(which python3)
       -DCMAKE_CUDA_ARCHITECTURES="75;86"
     - ccache -z
     - cmake --build build -- -j6
     - ccache -s
+test_gnn_tensorrt:
+  stage: test
+  needs:
+    - build_gnn_tensorrt
+  image: ghcr.io/acts-project/ubuntu2404_tensorrt:74
+  tags:
+    - docker-gpu-nvidia
+  script:
+    - apt-get update -y
+    - git clone $CLONE_URL src
+    - cd src
+    - git checkout $HEAD_SHA
+    - source CI/dependencies.sh
+    - cd ..
+    - ctest --test-dir build -R TensorRT
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- .gitlab-ci.yml (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: merge-sentinel
- GitHub Check: CI Bridge / build_gnn_tensorrt
- GitHub Check: unused_files
- GitHub Check: macos
- GitHub Check: linux_ubuntu_extra (ubuntu2204_clang, 20)
- GitHub Check: missing_includes
- GitHub Check: build_debug
- GitHub Check: linux_ubuntu
- GitHub Check: docs
🔇 Additional comments (1)
.gitlab-ci.yml (1)
191-191: Verify the image tag, we must!

Different from other jobs using tag '63', this one uses tag '74'. Intentional this change is?
✅ Verification successful
Valid and intentional, this image tag is!
Created recently, the image tag '74' has been. Tagged as 'latest', it is, hmmmm. Trust in the Force and the wisdom of the maintainers, we must.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```shell
#!/bin/bash
# Check if the image exists and compare its creation date with other images
gh api \
  -H "Accept: application/vnd.github+json" \
  -H "X-GitHub-Api-Version: 2022-11-28" \
  /orgs/acts-project/packages/container/ubuntu2404_tensorrt/versions \
  | jq '.[] | select(.metadata.container.tags[] | contains("74"))'
```

Length of output: 2322
Actionable comments posted: 1
♻️ Duplicate comments (1)
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (1)
49-49: ⚠️ Potential issue: Return a proper torch::Device object, you must.

Return device type directly, you do. A proper torch::Device instance, needed it is. Apply this diff to return the correct device:

```diff
-  torch::Device device() const override { return torch::kCUDA; };
+  torch::Device device() const override { return torch::Device(torch::kCUDA); };
```
🧹 Nitpick comments (1)
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (1)

29-38: Document the configuration parameters, you must.

Mysterious, these magic numbers are. Documentation for `cut = 0.21` and validation for `deviceID`, needed they are. Add comments explaining:
- The meaning and impact of the cut threshold
- Valid ranges for deviceID
- Purpose of each configuration parameter
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (7)
- GitHub Check: linux_examples_test
- GitHub Check: linux_physmon
- GitHub Check: linux_ubuntu_extra (ubuntu2204_clang, 20)
- GitHub Check: macos
- GitHub Check: build_debug
- GitHub Check: CI Bridge / build_gnn_tensorrt
- GitHub Check: missing_includes
🔇 Additional comments (2)
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (2)
61-62: Wise use of mutex for thread safety, I sense.

Protected well, the execution contexts are. Thread-safe design, you have achieved.
18-23: Verify TensorRT headers in CI environment, we must.

Failed to compile in CI, the PR has. Check if TensorRT headers are available in CI environment, we should.
✅ Verification successful
Expected behavior in the Force, this is. Worry, you must not.
In specialized container with TensorRT, the CI runs. Missing headers in standard paths, natural this is. When building locally, set `TensorRT_ROOT` to your installation path, you must. In CI environment, the dedicated container `ghcr.io/acts-project/ubuntu2404_tensorrt:74` handles this automatically, it does.

🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```shell
#!/bin/bash
# Search for TensorRT headers in system paths
find /usr/include -name "NvInfer.h" 2>/dev/null || echo "TensorRT headers not found in /usr/include"
find /usr/local/include -name "NvInfer.h" 2>/dev/null || echo "TensorRT headers not found in /usr/local/include"
# Check if TensorRT is mentioned in package configuration
if command -v pkg-config &> /dev/null; then
  pkg-config --list-all | grep -i tensorrt || echo "No TensorRT package config found"
fi
```

Length of output: 303
Script:

```shell
#!/bin/bash
# Check CMake configuration for TensorRT
fd CMakeLists.txt --exec grep -l "TensorRT" {} \;
fd "*.cmake" --exec grep -l "TensorRT" {} \;
# Check documentation for TensorRT requirements
fd "README|INSTALL" --extension md --exec grep -i -A 5 -B 5 "tensorrt" {} \;
# Check environment variable usage in build files
rg "TENSORRT|TensorRT" --type cmake
```

Length of output: 6096
Script:

```shell
#!/bin/bash
# Check build configuration options
fd "BuildOptions.cmake|CMakeCache.txt" --exec grep -A 5 -B 5 "EXATRKX.*TENSORRT" {} \;
# Look for any build documentation
rg -i "tensorrt|cuda" docs/ || echo "No direct mentions in docs/"
# Check CI configuration for TensorRT setup
fd "\.gitlab-ci\.yml|\.github/workflows" --exec grep -A 10 -B 10 -i "tensorrt" {} \;
```

Length of output: 2251
Quality Gate passed
Actionable comments posted: 3
🧹 Nitpick comments (1)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
37-37: Const correctness in log method, improve we must.

The `log` method should mark the `msg` parameter as `const char* const`, it should.

```diff
-  void log(Severity severity, const char *msg) noexcept override {
+  void log(Severity severity, const char* const msg) noexcept override {
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp (1 hunks)
- Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp
🧰 Additional context used
📓 Learnings (1)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
Learnt from: benjaminhuth
PR: acts-project/acts#4016
File: Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp:44-46
Timestamp: 2025-01-21T10:16:57.820Z
Learning: In Acts project, when a derived class implements a virtual method, the documentation can be found in the base class and doesn't need to be repeated. For example, TensorRTEdgeClassifier's operator() documentation is in EdgeClassificationBase.
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: CI Bridge / build_gnn_tensorrt
- GitHub Check: merge-sentinel
🔇 Additional comments (1)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
181-183: ⚠️ Potential issue: Fix undefined variable and improve memory management, we must.

Variable 'outputMem' undefined it is, and memory management improved it must be.

```diff
-  auto scores = torch::from_blob(
-      outputMem, edgeIndex.size(1), 1, [](void *ptr) { cudaFree(ptr); },
-      torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
+  // scores tensor already created and used for inference
+  // no need to create new tensor from blob
```

Likely invalid or redundant comment.
```cpp
std::size_t freeMem, totalMem;
cudaMemGetInfo(&freeMem, &totalMem);
ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
           << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
           << " GB");
```
Check cudaMemGetInfo status, you must.
Handle errors from CUDA memory information retrieval, we should.
```diff
 std::size_t freeMem, totalMem;
-cudaMemGetInfo(&freeMem, &totalMem);
+cudaError_t err = cudaMemGetInfo(&freeMem, &totalMem);
+if (err != cudaSuccess) {
+  throw std::runtime_error("Failed to get CUDA memory info: " +
+                           std::string(cudaGetErrorString(err)));
+}
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
std::size_t freeMem, totalMem;
cudaError_t err = cudaMemGetInfo(&freeMem, &totalMem);
if (err != cudaSuccess) {
  throw std::runtime_error("Failed to get CUDA memory info: " +
                           std::string(cudaGetErrorString(err)));
}
ACTS_DEBUG("Used CUDA memory after TensorRT initialization: "
           << (totalMem - freeMem) * 1e-9 << " / " << totalMem * 1e-9
           << " GB");
```
```cpp
while (context == nullptr) {
  std::lock_guard<std::mutex> lock(m_contextMutex);
  if (!m_contexts.empty()) {
    context = std::move(m_contexts.back());
    m_contexts.pop_back();
  }
}
```
Prevent deadlock in context acquisition, we must.
Infinite loop without timeout or condition variable, dangerous it is. Add timeout or condition variable to prevent deadlock, we should.
```diff
+  auto start = std::chrono::steady_clock::now();
   while (context == nullptr) {
     std::lock_guard<std::mutex> lock(m_contextMutex);
     if (!m_contexts.empty()) {
       context = std::move(m_contexts.back());
       m_contexts.pop_back();
+    } else {
+      auto now = std::chrono::steady_clock::now();
+      auto elapsed = std::chrono::duration_cast<std::chrono::seconds>
+          (now - start).count();
+      if (elapsed > 60) {  // 1 minute timeout
+        throw std::runtime_error("Timeout waiting for available context");
+      }
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
     }
   }
```
📝 Committable suggestion
```cpp
auto start = std::chrono::steady_clock::now();
while (context == nullptr) {
  std::lock_guard<std::mutex> lock(m_contextMutex);
  if (!m_contexts.empty()) {
    context = std::move(m_contexts.back());
    m_contexts.pop_back();
  } else {
    auto now = std::chrono::steady_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::seconds>
        (now - start).count();
    if (elapsed > 60) {  // 1 minute timeout
      throw std::runtime_error("Timeout waiting for available context");
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}
```
```cpp
auto nodeFeatures =
    std::any_cast<torch::Tensor>(inNodeFeatures).to(execContext.device);

auto edgeIndex =
    std::any_cast<torch::Tensor>(inEdgeIndex).to(execContext.device);
ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

auto edgeFeatures =
    std::any_cast<torch::Tensor>(inEdgeFeatures).to(execContext.device);
ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});
```
🛠️ Refactor suggestion
Validate input tensors before device transfer, we must.
Check tensor validity and dimensions before device transfer, essential it is.
```diff
 auto nodeFeatures =
     std::any_cast<torch::Tensor>(inNodeFeatures).to(execContext.device);
+if (!nodeFeatures.defined() || nodeFeatures.dim() != 2) {
+  throw std::runtime_error("Invalid node features tensor");
+}
 auto edgeIndex =
     std::any_cast<torch::Tensor>(inEdgeIndex).to(execContext.device);
+if (!edgeIndex.defined() || edgeIndex.dim() != 2 || edgeIndex.size(0) != 2) {
+  throw std::runtime_error("Invalid edge index tensor");
+}
 ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});
 auto edgeFeatures =
     std::any_cast<torch::Tensor>(inEdgeFeatures).to(execContext.device);
+if (!edgeFeatures.defined() || edgeFeatures.dim() != 2) {
+  throw std::runtime_error("Invalid edge features tensor");
+}
 ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});
```
📝 Committable suggestion
```cpp
auto nodeFeatures =
    std::any_cast<torch::Tensor>(inNodeFeatures).to(execContext.device);
if (!nodeFeatures.defined() || nodeFeatures.dim() != 2) {
  throw std::runtime_error("Invalid node features tensor");
}

auto edgeIndex =
    std::any_cast<torch::Tensor>(inEdgeIndex).to(execContext.device);
if (!edgeIndex.defined() || edgeIndex.dim() != 2 || edgeIndex.size(0) != 2) {
  throw std::runtime_error("Invalid edge index tensor");
}
ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

auto edgeFeatures =
    std::any_cast<torch::Tensor>(inEdgeFeatures).to(execContext.device);
if (!edgeFeatures.defined() || edgeFeatures.dim() != 2) {
  throw std::runtime_error("Invalid edge features tensor");
}
ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});
```
Actionable comments posted: 1
🧹 Nitpick comments (1)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
197-203: Structured performance monitoring system, implement we should.

Instead of scattered timing measurements, a dedicated performance monitoring system, beneficial it would be. Consider these improvements:
- Structured timing data collection
- Memory usage tracking at key points
- Performance metrics aggregation

Example implementation structure:

```cpp
struct PerformanceMetrics {
  double deviceTransferTime;
  double inferenceTime;
  double postProcessingTime;
  size_t peakMemoryUsage;
  // Add more metrics as needed
};

// Add to class members
std::vector<PerformanceMetrics> m_metrics;
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1 hunks)
🧰 Additional context used
📓 Learnings (1)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (1)
Learnt from: benjaminhuth
PR: acts-project/acts#4016
File: Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp:44-46
Timestamp: 2025-01-21T10:16:57.820Z
Learning: In Acts project, when a derived class implements a virtual method, the documentation can be found in the base class and doesn't need to be repeated. For example, TensorRTEdgeClassifier's operator() documentation is in EdgeClassificationBase.
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: merge-sentinel
- GitHub Check: CI Bridge / build_gnn_tensorrt
- GitHub Check: macos
🔇 Additional comments (4)
Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp (4)
30-57: Well implemented, the logger class is!

Clean and proper mapping of TensorRT severity levels to Acts logging levels, I see. RAII principles with `unique_ptr`, you follow.
104-108: Check cudaMemGetInfo status, you must. Handle errors from CUDA memory information retrieval, we should.
124-133: Validate input tensors before device transfer, we must. Check tensor validity and dimensions before device transfer, essential it is.
139-145: Prevent deadlock in context acquisition, we must. Infinite loop without timeout or condition variable, dangerous it is.
```cpp
auto stream = execContext.stream.value().stream();
auto status = context->enqueueV3(stream);
if (!status) {
  throw std::runtime_error("Failed to execute TensorRT model");
}
ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
```
🛠️ Refactor suggestion
Ensure proper cleanup after inference failure, you must.
If inference fails, CUDA stream and resources must be properly cleaned up, they should be.
Apply this diff to ensure proper cleanup:
```diff
 auto stream = execContext.stream.value().stream();
 auto status = context->enqueueV3(stream);
 if (!status) {
+  {
+    std::lock_guard<std::mutex> lock(m_contextMutex);
+    m_contexts.push_back(std::move(context));
+  }
   throw std::runtime_error("Failed to execute TensorRT model");
 }
 ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
```
📝 Committable suggestion
```cpp
auto stream = execContext.stream.value().stream();
auto status = context->enqueueV3(stream);
if (!status) {
  {
    std::lock_guard<std::mutex> lock(m_contextMutex);
    m_contexts.push_back(std::move(context));
  }
  throw std::runtime_error("Failed to execute TensorRT model");
}
ACTS_CUDA_CHECK(cudaStreamSynchronize(stream));
```
Cannot be compiled currently in the CI
--- END COMMIT MESSAGE ---
Any further description goes here, @-mentions are ok here!
`feat`, `fix`, `refactor`, `docs`, `chore` and `build` types.

Summary by CodeRabbit
Release Notes
- New Features
- Infrastructure
- Technical Improvements