PyTorch 2.6, CUDA 12.6, TensorRT 10.7, Python 3.12
Torch-TensorRT 2.6.0 targets PyTorch 2.6, TensorRT 10.7, and CUDA 12.6 (builds for CUDA 11.8/12.4 are available via the PyTorch package index: https://download.pytorch.org/whl/cu118, https://download.pytorch.org/whl/cu124). Python versions 3.9 through 3.12 are supported. Python 3.13 is not supported in this release because TensorRT does not yet support that version of Python.
Deprecation notice
The TorchScript frontend will be deprecated in v2.6. Specifically, the following usage will no longer be supported and will issue a deprecation warning at runtime if used:
torch_tensorrt.compile(model, ir="torchscript")
Moving forward, we encourage users to transition to one of the supported options:
torch_tensorrt.compile(model)
torch_tensorrt.compile(model, ir="dynamo")
torch.compile(model, backend="tensorrt")
TorchScript will continue to be supported as a deployment format via post-compilation tracing:
dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
ts_model = torch.jit.trace(dynamo_model, inputs=[...])
ts_model(...)
Please refer to the README for more information regarding our deprecation policy.
Cross-OS Compilation
In Torch-TensorRT 2.6 it is now possible to use a Linux host to compile Torch-TensorRT programs for Windows using the torch_tensorrt.cross_compile_for_windows
API. These programs use a slightly different serialization format to facilitate this workflow and cannot be run on Linux. Therefore, when calling torch_tensorrt.cross_compile_for_windows
expect the program to be saved directly to disk. Developers should then use the torch_tensorrt.load_cross_compiled_exported_program
on the Windows target to load the serialized program. Torch-TensorRT programs now include target platform information to verify OS compatibility on deserialization. This in turn has caused an ABI bump for the runtime.
if load:
    # Load the saved model on Windows
    if platform.system() != "Windows" or platform.machine() != "AMD64":
        raise ValueError(
            "cross runtime compiled model for windows can only be loaded in Windows system"
        )
    loaded_model = torchtrt.load_cross_compiled_exported_program(save_path).module()
    print(f"model has been successfully loaded from {save_path}")
    # inference
    trt_output = loaded_model(input)
    print(f"inference result: {trt_output}")
else:
    # Cross-compile on Linux and save the serialized program to disk
    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
        raise ValueError(
            "cross runtime compiled model for windows can only be compiled in Linux system"
        )
    compile_spec = {
        "debug": True,
        "min_block_size": 1,
    }
    torchtrt.cross_compile_for_windows(
        model, file_path=save_path, inputs=inputs, **compile_spec
    )
    print(
        f"model has been successfully cross compiled and saved in Linux to {save_path}"
    )
Runtime Weight Streaming
Weight Streaming in Torch-TensorRT is a memory optimization technique that helps deploy large models on memory-constrained devices by dynamically loading weights as needed during inference, reducing the overall memory footprint and enabling more efficient use of hardware resources. It is an opt-in feature that needs to be enabled at both build time and runtime.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=input_tensors,
    enabled_precisions={torch.float32},  # only float32 precision is allowed for strongly typed network
    use_explicit_typing=True,  # create a strongly typed network
    enable_weight_streaming=True,  # enable weight streaming
)
Control the weight streaming budget at runtime using the weight streaming context manager:
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget

    # Set 50% weight streaming budget
    requested_budget = int(streamable_budget * 0.5)
    weight_streaming_ctx.device_budget = requested_budget

    trt_model(inputs)
Intra-Block CUDAGraphs
We updated the CUDAGraphs API to support Intra-Block CUDAGraphs. Previously, when a compiled Torch-TensorRT module had graph breaks, only the TensorRT blocks could be run with CUDAGraphs' optimized kernel launches. With Torch-TensorRT 2.6, the entire graph can be captured and executed in a unified CUDAGraph to minimize kernel launch overhead.
# Previous API
with torch_tensorrt.runtime.enable_cudagraphs():
    torchtrt_model(inputs)

# New API
with torch_tensorrt.runtime.enable_cudagraphs(torchtrt_model) as cudagraphs_model:
    cudagraphs_model(input)
Improvements to Engine Caching
First, there are some API changes:
- `make_refittable` was renamed to `immutable_weights` in preparation for a future release that will default engines to be compiled with the refit feature enabled, allowing the Torch-TensorRT engine cache to provide maximum benefit.
- `refit_identical_engine_weights` was added to specify whether to refit the engine with identical weights.
- `strip_engine_weights` was added to specify whether to strip the engine weights.
- The default disk size for engine caching was expanded to 5 GB.
In addition, one of the capabilities of engine caching is to recognize whether two graphs are isomorphic. If a new graph is isomorphic to any previously compiled TensorRT engine, the engine cache will reuse that engine instead of recompiling the graph, thereby avoiding recompilation time. In the previous release, we used `FxGraphCachePickler.get_hash(new_gm)` from PyTorch to calculate hash values, which took up a large portion of the total compile time. In this release, we designed a new hash function that computes hash values quickly and then determines isomorphism, achieving a ~4x speedup.
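The new knobs are exposed as dynamo compile settings. Below is a minimal sketch, assuming a CUDA-capable GPU; the toy model, cache directory, and cache size are illustrative assumptions rather than required values.

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 16, device="cuda")]
exp_program = torch.export.export(model, tuple(inputs))

trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    immutable_weights=False,  # renamed from make_refittable; False keeps the engine refittable
    refit_identical_engine_weights=False,  # new: whether to refit the engine with identical weights
    strip_engine_weights=False,  # new: whether to strip weights from the serialized engine
    cache_built_engines=True,  # store newly built engines in the cache
    reuse_cached_engines=True,  # reuse a cached engine when an isomorphic graph is seen
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # assumed cache location
    engine_cache_size=5 * (1 << 30),  # 5 GB, matching the new default budget
)
trt_model(*inputs)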
C++11 ABI Changes
To keep pace with PyTorch, as of release 2.6 we switched docker images from manylinux to manylinux2_28. In Torch/Torch-TensorRT 2.6, `PRE_CXX11_ABI` is used for CUDA 11.8 and 12.4, while `CXX11_ABI` is used for CUDA 12.6. For Torch/Torch-TensorRT 2.7, `CXX11_ABI` will be used for all of CUDA 11.8, 12.4, and 12.6.
Explicit Typing
We introduce a new compilation setting, `use_explicit_typing`, to enable mixed precision inference with Torch-TensorRT. When this flag is enabled, TensorRT operates in strong typing mode, ensuring that layer data types are preserved during compilation. For a detailed demonstration of this behavior, refer to the provided tutorial. To learn more about strong typing in TensorRT, refer to the relevant section in the TensorRT Developer Guide.
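Below is a minimal sketch of mixed precision inference with this flag, assuming a CUDA-capable GPU; the toy module and input shape are illustrative assumptions rather than the tutorial's code.

import torch
import torch_tensorrt

class MixedPrecisionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)         # authored in FP32
        self.fc2 = torch.nn.Linear(10, 10).half()  # authored in FP16

    def forward(self, x):
        y = self.fc1(x)         # runs in FP32
        z = self.fc2(y.half())  # runs in FP16
        return z.float()

model = MixedPrecisionModule().eval().cuda()
inputs = [torch.randn(1, 10, device="cuda")]

# With use_explicit_typing=True, TensorRT builds a strongly typed network and
# preserves the FP32/FP16 layer dtypes above instead of autotuning precisions.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    use_explicit_typing=True,
)
print(trt_model(*inputs))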
Model Zoo
- We have added Segment Anything Model 2 (SAM2) compilation using Torch-TensorRT to our model zoo. The example can be found here.
- We have also added a torch.compile example for GPT2 using the `tensorrt` backend. This example demonstrates the use of the HuggingFace `generate` API for auto-regressive decoding; a minimal sketch follows this list. For the export-based workflow (`ir="dynamo"`), we provide a custom generate function to handle output decoding.
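The torch.compile path roughly follows the pattern below. This is a minimal sketch, assuming the HuggingFace transformers library is installed and a CUDA-capable GPU is available; the model name, backend options, and generation parameters are illustrative assumptions, not the exact example code.

import torch
import torch_tensorrt  # registers the "tensorrt" backend for torch.compile
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval().cuda()

# Compile only the forward pass; HuggingFace generate() then drives
# auto-regressive decoding over the compiled function.
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    options={"enabled_precisions": {torch.float32}, "min_block_size": 1},
)

input_ids = tokenizer("Torch-TensorRT is", return_tensors="pt").input_ids.cuda()
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))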
Multi-GPU Improvements
There are experimental improvements to multi-GPU workflows, including automatically pulling NCCL operations into TensorRT subgraphs. These should be considered alpha quality. More information can be found here: https://github.com/pytorch/TensorRT/tree/main/examples/distributed_inference
What's Changed
- upgrade modelopt by @lanluo-nvidia in #3160
- feat: exclude refit sensitive ops from TRT compilation by @peri044 in #3159
- tool: Adding support for the uv system by @narendasan in #3125
- upgrade torch from 2.5.0.dev to 2.6.0.dev in main branch by @lanluo-nvidia in #3165
- fix: Fix static arange export by @peri044 in #3194
- docs: A tutorial on how to overload converters in Torch-TensorRT by @narendasan in #3197
- Adjust cpp torch trt logging level with compiler option by @keehyuna in #3181
- extend the timeout-minutes in build/test from 60 min to 120 min by @lanluo-nvidia in #3203
- extend windows build from 60 min to 120 min by @lanluo-nvidia in #3218
- fix the global partitioner bug by @lanluo-nvidia in #3195
- feat: Implement FP32 accumulation for matmul by @peri044 in #3110
- chore: Make substitute-runner in Windows CI work again by @HolyWu in #3225
- Run test_base_fp8 for compute capability 8.9 or later by @HolyWu in #3164
- Fixed batchnorm bug by @cehongwang in #3170
- Fix for warning as default stream was used in enqueueV3 by @keehyuna in #3191
- chore: doc updates by @peri044 in #3238
- chore: Additional Doc fixes by @peri044 in #3243
- docs: escape dash to avoid collapsing -- to - by @dgcnz in #3235
- feat: log_softmax decomposition by @HolyWu in #3137
- fix: change floordiv to divmod for `//tests/core/lowering:test_remove_unnecessary_casts` by @zewenli98 in #3223
- Add support for JetPack 6.1 build by @lanluo-nvidia in #3211
- Require full compilation arg by @apbose in #3193
- Fix code example in README.md by @juliusgh in #3253
- chore: Access user settings within the lowering system by @peri044 in #3245
- fix: expand dim for scalar numpy when freezing tensors to IConstantLayers by @chohk88 in #3251
- chore: Adapt CIA ops decomposition handling in upsample converters to torch 2.6 by @HolyWu in #3227
- feat: Support weight streaming by @keehyuna in #3111
- fix issue 3259 by @lanluo-nvidia in #3260
- skip dummy inference and run_shape_analysis by @lanluo-nvidia in #3212
- Remove numpy version constraint in test requirements by @HolyWu in #3264
- fix issue#3269: unwrap tensor shape without opt val by @lanluo-nvidia in #3279
- disable python 3.13 for linux by @lanluo-nvidia in #3271
- switch from fx.symbolic_trace to dynamo_trace for converter test part-1 by @lanluo-nvidia in #3261
- cross compile for windows by @lanluo-nvidia in #3220
- chore: add source_ir in slice layer name by @jiwoong-choi in #3284
- fix MutableTorchTensorRTModule load issue by @lanluo-nvidia in #3281
- don't initialize cuda at import time by @technillogue in #3244
- change decomposition default table due to upstream torch change by @lanluo-nvidia in #3291
- feat: Support exporting Torch-TRT compiled Graphmodules by @peri044 in #3262
- Add tensorrt test workflow by @lanluo-nvidia in #3266
- test future tensorrt version in windows wf by @lanluo-nvidia in #3290
- fix: get_hash function for engine caching by @zewenli98 in #3293
- feat: InstanceNorm decomposition by @HolyWu in #3288
- MODULE.bazel file for NGC docker container by @apbose in #3156
- docs: Updated tutorial for triton + torch-tensorrt by @narendasan in #3292
- Fix LayerNorm fp16 precision by @HolyWu in #3272
- Get decompositions only for CIA ops by @HolyWu in #3297
- fix: cumsum add_constant bug fix (add dtype for np zeros) by @chohk88 in #3258
- fix: change docker img from manylinux to manylinux2_28 for all CUDA versions by @zewenli98 in #3312
- fix: export USE_CXX11_ABI=1 for cuda12.6 by @zewenli98 in #3319
- feat: Support weight-stripped engine and REFIT_IDENTICAL flag by @zewenli98 in #3167
- fix: Fix additional mem copy of the model during re-export by @peri044 in #3302
- fix: Fix copying metadata during lowering by @peri044 in #3320
- upgrade tensorrt dependency to >=10.3.0,<=10.6.0 by @lanluo-nvidia in #3286
- Use INormalizationLayer for GroupNorm by @HolyWu in #3273
- Torch TRT ngc container changes by @apbose in #3299
- feat: Add SAM2 to our model zoo by @peri044 in #3318
- fix: Remove pytorch overhead while finding fusions for fully convertible models by @peri044 in #3311
- feat: Automatically generating converters for QDP plugins by @narendasan in #3321
- Fix failed TestNativeLayerNormConverter by @HolyWu in #3315
- fix: Correct mutex scope in execute_engine() by @keehyuna in #3310
- Replace scaled_dot_product_attention lowering pass with decomposition by @HolyWu in #3296
- chore: example fixes by @peri044 in #3176
- Remove linear lowering pass and converter by @HolyWu in #3323
- fix: Fix meta kernel ops import issue for Python only builds by @peri044 in #3322
- Add test case for ITensor weight in convolution and fix related bug by @chohk88 in #3327
- feat: Runtime output buffer optimization by @keehyuna in #3276
- full_like to full decomposition moving to decomposition.py for dynami… by @apbose in #3289
- Wrapper module around TRT + pytorch subgraphs by @keehyuna in #3270
- feat: add args for profiling engine caching by @zewenli98 in #3329
- fix: update bazelisk to fix build errors by @peri044 in #3328
- chore: adding additional logging to the converter registry system by @narendasan in #3199
- chore: revert attention decomposition due to flux bug by @peri044 in #3332
- fix: Fix null inputs case by @peri044 in #3334
- fix: Record cudagraphs when weight streaming budget has changed by @keehyuna in #3309
- Cherrypick: Bump TRT version to 10.7 by @zewenli98 in #3341
- Cherrypick: nccl ops multi gpu by @apbose in #3342
- Update build-test-linux.yml by @narendasan in #3345
- fix: CI errors on release 2.6 by @zewenli98 in #3358
- fix: CI docker build error for release 2.6 by @zewenli98 in #3360
- [cherry-pick] trtp for 2.6 release by @narendasan in #3372
New Contributors
- @dgcnz made their first contribution in #3235
- @technillogue made their first contribution in #3244
Full Changelog: v2.5.0...v2.6.0