Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add TensorRT support for GNNs #4016

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
37 changes: 37 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,43 @@ test_exatrkx_python:
- pytest -rFsv -k torch --collect-only
- pytest -rFsv -k gpu-torch # For now only test torch GPU pipeline

# Build the ExaTrkX plugin with the TensorRT backend only (no Torch backend).
# Runs on a GPU-enabled runner since TensorRT requires the CUDA toolchain.
build_gnn_tensorrt:
  stage: build
  image: ghcr.io/acts-project/ubuntu2404_tensorrt:sha-b4f481f@sha256:8887aa00ad4394a53b4ca54968121d8893d537e5daf50805f1dd2030caef78ce
  variables:
    DEPENDENCY_URL: https://acts.web.cern.ch/ACTS/ci/ubuntu-24.04/deps.$DEPENDENCY_TAG.tar.zst

  cache:
    key: ccache-${CI_JOB_NAME}-${CI_COMMIT_REF_SLUG}-${CCACHE_KEY_SUFFIX}
    fallback_keys:
      - ccache-${CI_JOB_NAME}-${CI_DEFAULT_BRANCH}-${CCACHE_KEY_SUFFIX}
    # Upload the ccache even when the build fails, so retries stay warm
    when: always
    paths:
      - ${CCACHE_DIR}

  tags:
    - docker-gpu-nvidia

  script:
    - git clone $CLONE_URL src
    - cd src
    - git checkout $HEAD_SHA
    - source CI/dependencies.sh
    - cd ..
    - mkdir build
    - >
      cmake -B build -S src
      -DACTS_BUILD_PLUGIN_EXATRKX=ON
      -DACTS_EXATRKX_ENABLE_TORCH=OFF
      -DACTS_EXATRKX_ENABLE_CUDA=ON
      -DACTS_EXATRKX_ENABLE_TENSORRT=ON
      -DPython_EXECUTABLE=$(which python3)
      -DCMAKE_CUDA_ARCHITECTURES="75;86"
    - ccache -z
    - cmake --build build -- -j6
    - ccache -s


build_linux_ubuntu:
stage: build
image: ghcr.io/acts-project/ubuntu2404:63
Expand Down
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,6 @@ if(ACTS_BUILD_PLUGIN_EXATRKX)
else()
message(STATUS "Build Exa.TrkX plugin for CPU only")
endif()
if(NOT (ACTS_EXATRKX_ENABLE_ONNX OR ACTS_EXATRKX_ENABLE_TORCH))
message(
FATAL_ERROR
"When building the Exa.TrkX plugin, at least one of ACTS_EXATRKX_ENABLE_ONNX \
and ACTS_EXATRKX_ENABLE_TORCHSCRIPT must be enabled."
)
endif()
if(ACTS_EXATRKX_ENABLE_TORCH)
find_package(TorchScatter REQUIRED)
endif()
Expand Down
27 changes: 27 additions & 0 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
Expand Down Expand Up @@ -112,6 +113,32 @@ void addExaTrkXTrackFinding(Context &ctx) {
}
#endif

#ifdef ACTS_EXATRKX_WITH_TENSORRT
  {
    using Alg = Acts::TensorRTEdgeClassifier;
    using Config = Alg::Config;

    // Python binding for the TensorRT-backed edge classifier stage
    auto alg =
        py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
            mex, "TensorRTEdgeClassifier")
            .def(py::init([](const Config &c, Logging::Level lvl) {
                   return std::make_shared<Alg>(
                       c, getDefaultLogger("EdgeClassifier", lvl));
                 }),
                 py::arg("config"), py::arg("level"))
            .def_property_readonly("config", &Alg::config);

    auto c = py::class_<Config>(alg, "Config").def(py::init<>());
    ACTS_PYTHON_STRUCT_BEGIN(c, Config);
    ACTS_PYTHON_MEMBER(modelPath);
    ACTS_PYTHON_MEMBER(selectedFeatures);
    ACTS_PYTHON_MEMBER(cut);
    ACTS_PYTHON_MEMBER(deviceID);
    // Expose useEdgeFeatures as well — it is part of Alg::Config but was
    // missing from the bindings, making it unconfigurable from Python
    ACTS_PYTHON_MEMBER(useEdgeFeatures);
    ACTS_PYTHON_MEMBER(doSigmoid);
    ACTS_PYTHON_STRUCT_END();
  }
#endif

#ifdef ACTS_EXATRKX_ONNX_BACKEND
{
using Alg = Acts::OnnxMetricLearning;
Expand Down
14 changes: 14 additions & 0 deletions Plugins/ExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ if(ACTS_EXATRKX_ENABLE_TORCH)
)
endif()

if(ACTS_EXATRKX_ENABLE_TENSORRT)
    # Optional TensorRT backend for the GNN edge classifier
    find_package(TensorRT REQUIRED)
    message(STATUS "Found TensorRT ${TensorRT_VERSION}")
    target_sources(ActsPluginExaTrkX PRIVATE src/TensorRTEdgeClassifier.cpp)
    # PUBLIC: downstream consumers (e.g. the Python bindings) need both the
    # TensorRT libraries and the feature macro
    target_link_libraries(
        ActsPluginExaTrkX
        PUBLIC trt::nvinfer trt::nvinfer_plugin
    )
    target_compile_definitions(
        ActsPluginExaTrkX
        PUBLIC ACTS_EXATRKX_WITH_TENSORRT
    )
endif()

target_include_directories(
ActsPluginExaTrkX
PUBLIC
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <memory>

#include <torch/torch.h>

namespace nvinfer1 {
class IRuntime;
class ICudaEngine;
class ILogger;
class IExecutionContext;
} // namespace nvinfer1

namespace Acts {

class TensorRTEdgeClassifier final : public Acts::EdgeClassificationBase {
public:
struct Config {
std::string modelPath;
std::vector<int> selectedFeatures = {};
float cut = 0.21;
int deviceID = 0;
bool useEdgeFeatures = false;
bool doSigmoid = true;
};

TensorRTEdgeClassifier(const Config &cfg,
std::unique_ptr<const Logger> logger);
~TensorRTEdgeClassifier();

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
const ExecutionContext &execContext = {}) override;
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved

Config config() const { return m_cfg; }
torch::Device device() const override { return torch::kCUDA; };
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved

private:
std::unique_ptr<const Acts::Logger> m_logger;
const auto &logger() const { return *m_logger; }

Config m_cfg;

std::unique_ptr<nvinfer1::IRuntime> m_runtime;
std::unique_ptr<nvinfer1::ICudaEngine> m_engine;
std::unique_ptr<nvinfer1::ILogger> m_trtLogger;
std::unique_ptr<nvinfer1::IExecutionContext> m_context;
};

} // namespace Acts
190 changes: 190 additions & 0 deletions Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>
#include <filesystem>
#include <fstream>
#include <stdexcept>

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntimeBase.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace {

class TensorRTLogger : public nvinfer1::ILogger {
std::unique_ptr<const Acts::Logger> m_logger;

public:
TensorRTLogger(Acts::Logging::Level lvl)
: m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}

void log(Severity severity, const char *msg) noexcept override {
const auto &logger = *m_logger;
switch (severity) {
case Severity::kVERBOSE:
ACTS_DEBUG(msg);
break;
case Severity::kINFO:
ACTS_INFO(msg);
break;
case Severity::kWARNING:
ACTS_WARNING(msg);
break;
case Severity::kERROR:
ACTS_ERROR(msg);
break;
case Severity::kINTERNAL_ERROR:
ACTS_FATAL(msg);
break;
}
}
};

} // namespace

namespace Acts {

TensorRTEdgeClassifier::TensorRTEdgeClassifier(
    const Config &cfg, std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)),
      m_cfg(cfg),
      m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
  // Register the standard TensorRT plugins before deserializing, in case
  // the engine references any of them. The previous assert() compiled out
  // under NDEBUG, silently ignoring failures.
  bool pluginStatus = initLibNvInferPlugins(m_trtLogger.get(), "");
  if (!pluginStatus) {
    throw std::runtime_error("Failed to initialize TensorRT plugins");
  }

  std::size_t fsize =
      std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
  std::vector<char> engineData(fsize);

  ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);

  // The engine file is a binary blob; open in binary mode so no newline
  // translation can corrupt it on non-POSIX platforms.
  std::ifstream engineFile(m_cfg.modelPath, std::ios::binary);
  if (!engineFile) {
    throw std::runtime_error("Failed to open engine file: " + m_cfg.modelPath);
  }
  if (!engineFile.read(engineData.data(), fsize)) {
    throw std::runtime_error("Failed to read engine file: " + m_cfg.modelPath);
  }

  m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
  if (!m_runtime) {
    throw std::runtime_error("Failed to create TensorRT runtime");
  }

  m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
  if (!m_engine) {
    throw std::runtime_error("Failed to deserialize TensorRT engine from " +
                             m_cfg.modelPath);
  }

  m_context.reset(m_engine->createExecutionContext());
  if (!m_context) {
    throw std::runtime_error("Failed to create TensorRT execution context");
  }
}

TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}

// Elapsed time between two time points, in (fractional) milliseconds.
// Internal linkage: this is a file-local debugging helper.
static const auto milliseconds = [](const auto &a, const auto &b) {
  return std::chrono::duration<double, std::milli>(b - a).count();
};

// RAII timer used by the TIME_* macros below: prints the time between
// construction (TIME_BEGIN) and the last TIME_END mark at scope exit.
struct TimePrinter {
  const char *name;
  decltype(std::chrono::high_resolution_clock::now()) t0, t1;
  TimePrinter(const char *n) : name(n) {
    t0 = std::chrono::high_resolution_clock::now();
    // Initialize t1 too, so the destructor never reads an uninitialized
    // time point if TIME_END is never reached (e.g. early return/throw).
    t1 = t0;
  }
  ~TimePrinter() {
    std::cout << name << ": " << milliseconds(t0, t1) << std::endl;
  }
};

// Flip the condition to 1 to enable the scoped timing printouts.
#if 0
#define TIME_BEGIN(name) TimePrinter printer##name(#name);
#define TIME_END(name) \
  printer##name.t1 = std::chrono::high_resolution_clock::now();
#else
#define TIME_BEGIN(name) /*nothing*/
#define TIME_END(name) /*nothing*/
#endif

std::tuple<std::any, std::any, std::any, std::any>
TensorRTEdgeClassifier::operator()(std::any inNodeFeatures,
                                   std::any inEdgeIndex,
                                   std::any inEdgeFeatures,
                                   const ExecutionContext &execContext) {
  decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
  t0 = std::chrono::high_resolution_clock::now();

  // The guard must be a named object so it lives for the whole call; the
  // previous `CUDAStreamGuard(...)` was a temporary destroyed immediately
  // and therefore had no effect.
  c10::cuda::CUDAStreamGuard streamGuard(execContext.stream.value());

  auto nodeFeatures =
      std::any_cast<torch::Tensor>(inNodeFeatures).to(torch::kCUDA);

  auto edgeIndex = std::any_cast<torch::Tensor>(inEdgeIndex).to(torch::kCUDA);
  ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

  auto edgeFeatures =
      std::any_cast<torch::Tensor>(inEdgeFeatures).to(torch::kCUDA);
  ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});

  t1 = std::chrono::high_resolution_clock::now();

  // Bind dynamic input shapes and device pointers for this inference call
  m_context->setInputShape(
      "x", nvinfer1::Dims2{nodeFeatures.size(0), nodeFeatures.size(1)});
  m_context->setTensorAddress("x", nodeFeatures.data_ptr());

  m_context->setInputShape(
      "edge_index", nvinfer1::Dims2{edgeIndex.size(0), edgeIndex.size(1)});
  m_context->setTensorAddress("edge_index", edgeIndex.data_ptr());

  m_context->setInputShape(
      "edge_attr", nvinfer1::Dims2{edgeFeatures.size(0), edgeFeatures.size(1)});
  m_context->setTensorAddress("edge_attr", edgeFeatures.data_ptr());

  // One float score per edge; ownership is handed to the `scores` tensor
  // below via its custom deleter.
  void *outputMem{nullptr};
  std::size_t outputSize = edgeIndex.size(1) * sizeof(float);
  if (cudaMalloc(&outputMem, outputSize) != cudaSuccess) {
    throw std::runtime_error("Failed to allocate output memory for TensorRT");
  }
  m_context->setTensorAddress("output", outputMem);

  t2 = std::chrono::high_resolution_clock::now();

  {
    auto stream = execContext.stream.value().stream();
    auto status = m_context->enqueueV3(stream);
    cudaStreamSynchronize(stream);
    ACTS_VERBOSE("TensorRT output status: " << std::boolalpha << status);
    if (!status) {
      // Without this check a failed enqueue would silently yield garbage
      cudaFree(outputMem);
      throw std::runtime_error("TensorRT inference (enqueueV3) failed");
    }
  }

  t3 = std::chrono::high_resolution_clock::now();

  auto scores = torch::from_blob(
      outputMem, edgeIndex.size(1), 1, [](void *ptr) { cudaFree(ptr); },
      torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));

  // Honor the configuration flag; previously the sigmoid was applied
  // unconditionally, ignoring Config::doSigmoid.
  if (m_cfg.doSigmoid) {
    scores.sigmoid_();
  }

  ACTS_VERBOSE("Size after classifier: " << scores.size(0));
  ACTS_VERBOSE("Slice of classified output:\n"
               << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
  printCudaMemInfo(logger());

  // Apply the score cut to both the edge list and the score tensor
  torch::Tensor mask = scores > m_cfg.cut;
  torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});

  scores = scores.masked_select(mask);
  ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
  printCudaMemInfo(logger());

  t4 = std::chrono::high_resolution_clock::now();

  ACTS_DEBUG("Time anycast: " << milliseconds(t0, t1));
  ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
  ACTS_DEBUG("Time inference: " << milliseconds(t2, t3));
  ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));

  return {nodeFeatures, edgesAfterCut, edgeFeatures, scores};
}

} // namespace Acts
Loading
Loading