Torch-TensorRT v2.6.0

@narendasan released this 05 Feb 22:03
44375f2

PyTorch 2.6, CUDA 12.6, TensorRT 10.7, Python 3.12

Torch-TensorRT 2.6.0 targets PyTorch 2.6, TensorRT 10.7, and CUDA 12.6. Builds for CUDA 11.8 and 12.4 are available via the PyTorch package index (https://download.pytorch.org/whl/cu118 and https://download.pytorch.org/whl/cu124). Python versions 3.9 through 3.12 are supported. Python 3.13 is not supported in this release because TensorRT does not support that version of Python at this time.
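
As a rough illustration of the support matrix above, the following sketch checks a Python version against the stated 3.9 to 3.12 range and looks up the alternate wheel index for a CUDA version. The helper names `is_supported_python` and `wheel_index_for` are hypothetical and are not part of Torch-TensorRT; only the version range and the two index URLs come from these notes.

```python
import sys

# Supported range stated in these release notes: Python 3.9 through 3.12.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 12)

# Alternate wheel indexes named in these notes (CUDA 12.6 is the default target).
CUDA_WHEEL_INDEX = {
    "11.8": "https://download.pytorch.org/whl/cu118",
    "12.4": "https://download.pytorch.org/whl/cu124",
}


def is_supported_python(version=None):
    """Return True if (major, minor) falls inside the supported range.

    Hypothetical helper; defaults to the running interpreter.
    """
    major_minor = tuple((version or sys.version_info)[:2])
    return MIN_PYTHON <= major_minor <= MAX_PYTHON


def wheel_index_for(cuda_version):
    """Return the alternate index URL listed for a CUDA version, or None if none is listed."""
    return CUDA_WHEEL_INDEX.get(cuda_version)
```

For example, `is_supported_python((3, 13))` is false under these assumptions, and `wheel_index_for("12.6")` returns `None` because no alternate index is listed for the default CUDA target.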

Deprecation notice

The TorchScript frontend is deprecated as of v2.6. Specifically, the following usage will no longer be supported and will issue a deprecation warning at runtime if used:

torch_tensorrt.compile(model, ir="torchscript")

Moving forward, we encourage users to transition to one of the supported options:

torch_tensorrt.compile(model)
torch_tensorrt.compile(model, ir="dynamo")
torch.compile(model, backend="tensorrt")

TorchScript will continue to be supported as a deployment format via post-compilation tracing:

dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
ts_model = torch.jit.trace(dynamo_model, inputs=[...])
ts_model(...)

Please refer to the README for more information regarding our deprecation policy.
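
One way to catch lingering uses of the deprecated frontend in a test suite is to record warnings around the call. The sketch below uses a hypothetical stand-in, `resolve_ir`, in place of the real compile call so that it runs without a GPU. It assumes the warning is emitted as a standard `DeprecationWarning`, which these notes do not confirm.

```python
import warnings


def resolve_ir(ir="dynamo"):
    """Hypothetical stand-in for frontend selection; NOT the library's code."""
    if ir == "torchscript":
        # Assumption: the deprecation is reported as a DeprecationWarning.
        warnings.warn(
            'ir="torchscript" is deprecated; use ir="dynamo" instead',
            DeprecationWarning,
            stacklevel=2,
        )
    return ir


def uses_deprecated_frontend(ir):
    """Return True if selecting this frontend emits a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones Python would normally suppress.
        warnings.simplefilter("always")
        resolve_ir(ir)
    return any(issubclass(w.category, DeprecationWarning) for w in caught)
```

In a real project, the same `catch_warnings` pattern could wrap the actual `torch_tensorrt.compile(...)` call instead of the stand-in.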

Cross-OS Compilation

In Torch-TensorRT 2.6, it is now possible to use a Linux host to compile Torch-TensorRT programs for Windows using the torch_tensorrt.cross_compile_for_windows API. These programs use a slightly different serialization format to facilitate this workflow and cannot be run on Linux. Therefore, torch_tensorrt.cross_compile_for_windows saves the program directly to disk. Developers should then use torch_tensorrt.load_cross_compiled_exported_program on the Windows target to load the serialized program. Torch-TensorRT programs now include target platform information to verify OS compatibility on deserialization. This in turn has caused an ABI bump for the runtime.

if load:
    # load the saved model in Windows
    if platform.system() != "Windows" or platform.machine() != "AMD64":
        raise ValueError(
            "cross runtime compiled model for windows can only be loaded in Windows system"
        )
    loaded_model = torchtrt.load_cross_compiled_exported_program(save_path).module()
    print(f"model has been successfully loaded from {save_path}")
    # inference
    trt_output = loaded_model(input)
    print(f"inference result: {trt_output}")
else:
    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
        raise ValueError(
            "cross runtime compiled model for windows can only be compiled in Linux system"
        )
    compile_spec = {
        "debug": True,
        "min_block_size": 1,
    }
    torchtrt.cross_compile_for_windows(
        model, file_path=save_path, inputs=inputs, **compile_spec
    )
    print(
        f"model has been successfully cross compiled and saved in Linux to {save_path}"
    )
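
The two platform guards in the example above can be factored into small predicates, which makes them easy to unit test on any machine. This is only a sketch: the function names are hypothetical, and the conditions simply mirror the checks shown above (Windows/AMD64 for loading, 64-bit Linux for compiling).

```python
import platform


def can_load_cross_compiled(system=None, machine=None):
    """True on the Windows/AMD64 target, mirroring the load-side check above."""
    system = system or platform.system()
    machine = machine or platform.machine()
    return system == "Windows" and machine == "AMD64"


def can_cross_compile(system=None, bits=None):
    """True on a 64-bit Linux host, mirroring the compile-side check above."""
    system = system or platform.system()
    bits = bits or platform.architecture()[0]
    return system == "Linux" and bits == "64bit"
```

Passing explicit arguments lets a test exercise both branches; calling the functions with no arguments inspects the current host.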

Runtime Weight Streaming

Weight Streaming in Torch-TensorRT is a memory optimization technique that helps deploy large models on memory-constrained devices by dynamically loading weights as needed during inference, reducing the overall memory footprint and enabling more efficient use of hardware resources. It is an opt-in feature that needs to be enabled at both build time and runtime.

trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=input_tensors,
    enabled_precisions={torch.float32}, # only float32 precision is allowed for strongly typed network
    use_explicit_typing=True,           # create a strongly typed network
    enable_weight_streaming=True,       # enable weight streaming
)

Control the weight streaming budget at runtime using the weight streaming context manager:

with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget
    # Set 50% weight streaming budget
    requested_budget = int(streamable_budget * 0.5)
    weight_streaming_ctx.device_budget = requested_budget
    trt_model(inputs)
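
The budget arithmetic in the snippet above can be isolated into a small helper that validates the requested fraction before converting it to bytes. The name `budget_from_fraction` is hypothetical and not part of the Torch-TensorRT API; it is a minimal sketch of the "50% of the streamable weights" calculation.

```python
def budget_from_fraction(total_bytes, fraction):
    """Return a device budget in bytes for a fraction of the streamable weights.

    Hypothetical helper: truncates to a whole number of bytes, as int() does
    in the example above.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    return int(total_bytes * fraction)
```

Inside the context manager, the result would be assigned as `weight_streaming_ctx.device_budget = budget_from_fraction(weight_streaming_ctx.total_device_budget, 0.5)`.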

Intra-Block CUDAGraphs

We updated the CUDAGraphs API to support Intra-Block CUDAGraphs. Previously, when a compiled Torch-TensorRT module had graph breaks, only the TensorRT blocks could be run with CUDAGraphs' optimized kernel launch. With Torch-TensorRT 2.6, the entire graph can be captured and executed in a unified CUDAGraph to minimize kernel launch overhead.

# Previous API
with torch_tensorrt.runtime.enable_cudagraphs():
    torchtrt_model(inputs)
# New API
with torch_tensorrt.runtime.enable_cudagraphs(torchtrt_model) as cudagraphs_model:
    cudagraphs_model(inputs)

Improvements to Engine Caching

First, there are some API changes:

  1. make_refittable was renamed to immutable_weights in preparation for a future release that will compile engines with the refit feature enabled by default, allowing the Torch-TensorRT engine cache to provide maximum benefit.
  2. refit_identical_engine_weights was added to specify whether to refit the engine with identical weights.
  3. strip_engine_weights was added to specify whether to strip the engine weights.
  4. The default disk size for engine caching was expanded to 5GB.
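
For code that still passes the old setting name, a small shim can translate a settings dictionary before it is forwarded to the compiler. The sketch below is hypothetical and rests on one loud assumption: that `immutable_weights` is the logical inverse of the old `make_refittable` flag, as the names suggest. These notes only state that the setting was renamed, so verify the semantics against the API documentation before relying on this.

```python
def migrate_refit_setting(settings):
    """Return a copy of settings with make_refittable mapped to immutable_weights.

    Hypothetical helper, not part of Torch-TensorRT.
    """
    migrated = dict(settings)
    if "make_refittable" in migrated:
        make_refittable = migrated.pop("make_refittable")
        # Assumption: the new flag is the logical inverse of the old one.
        # An explicitly provided immutable_weights value takes precedence.
        migrated.setdefault("immutable_weights", not make_refittable)
    return migrated
```

The input dictionary is left unmodified, so the shim can be applied at a call site without side effects.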

In addition, one of the capabilities of engine caching is to recognize whether two graphs are isomorphic. If a new graph is isomorphic to the graph of any previously compiled TensorRT engine, the engine cache reuses that engine instead of recompiling the graph, thereby avoiding recompilation time. In the previous release, we used FxGraphCachePickler.get_hash(new_gm) from PyTorch to calculate hash values, which took up a large portion of the total compile time. In this release, we designed a new hash function that computes hash values quickly, so isomorphism is determined with a ~4x speedup.
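
To give a feel for how a hash can stand in for an isomorphism check, here is a toy sketch that hashes only the structure of a graph (operator types and wiring) and ignores node names. It is not the hash function used by Torch-TensorRT; the node representation and the name `structural_hash` are assumptions made for illustration, and the sketch only matches graphs whose nodes are listed in the same topological order.

```python
import hashlib


def structural_hash(nodes):
    """Hash a graph by operator types and wiring, ignoring node names.

    nodes: list of (name, op, input_names) tuples in topological order,
    so every input name refers to an earlier node in the list.
    """
    position_of = {}
    parts = []
    for position, (name, op, input_names) in enumerate(nodes):
        position_of[name] = position
        # Replace each input name with the position of the node it refers to.
        inputs = ",".join(str(position_of[i]) for i in input_names)
        parts.append(f"{position}:{op}({inputs})")
    return hashlib.sha256(";".join(parts).encode("utf-8")).hexdigest()
```

Two graphs that differ only in node names produce the same digest and could share a cache entry, while a change to any operator or connection produces a different one.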

C++11 ABI Changes

To keep pace with PyTorch, as of release 2.6 we switched Docker images from manylinux to manylinux2_28. In Torch/Torch-TensorRT 2.6, PRE_CXX11_ABI is used for CUDA 11.8 and 12.4, while CXX11_ABI is used for CUDA 12.6. For Torch/Torch-TensorRT 2.7, CXX11_ABI will be used for all CUDA versions (11.8, 12.4, and 12.6).
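
The ABI matrix described above can be written down as a lookup table, which may be handy in build scripts that need to pick matching C++ flags. The table is transcribed from the paragraph above; the helper name `abi_for` is hypothetical.

```python
# ABI per (release, CUDA version), transcribed from the paragraph above.
ABI_TABLE = {
    ("2.6", "11.8"): "PRE_CXX11_ABI",
    ("2.6", "12.4"): "PRE_CXX11_ABI",
    ("2.6", "12.6"): "CXX11_ABI",
    ("2.7", "11.8"): "CXX11_ABI",
    ("2.7", "12.4"): "CXX11_ABI",
    ("2.7", "12.6"): "CXX11_ABI",
}


def abi_for(release, cuda_version):
    """Return the ABI name for a release/CUDA pair listed in these notes."""
    try:
        return ABI_TABLE[(release, cuda_version)]
    except KeyError:
        raise ValueError(
            f"no ABI listed for release {release} with CUDA {cuda_version}"
        ) from None
```

To check an installed PyTorch build directly rather than relying on the table, `torch.compiled_with_cxx11_abi()` reports whether that build uses the C++11 ABI.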

Explicit Typing

We introduce a new compilation setting, use_explicit_typing, to enable mixed precision inference with Torch-TensorRT. When this flag is enabled, TensorRT operates in strong typing mode, ensuring that layer data types are preserved during compilation. For a detailed demonstration of this behavior, refer to the provided tutorial. To learn more about strong typing in TensorRT, refer to the relevant section in the TensorRT Developer Guide.

Model Zoo

Multi-GPU Improvements

There are experimental improvements to multi-GPU workflows, including pulling NCCL operations into TensorRT subgraphs automatically. These should be considered to have alpha stability. More information can be found here: https://github.com/pytorch/TensorRT/tree/main/examples/distributed_inference

What's Changed

New Contributors

Full Changelog: v2.5.0...v2.6.0