Commit

docs: update docs for new api
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Jul 22, 2021
1 parent 265f71e commit 90c28f7
Showing 13 changed files with 584 additions and 578 deletions.
31 changes: 15 additions & 16 deletions README.md
@@ -18,13 +18,13 @@ More Information / System Architecture:
#include "trtorch/trtorch.h"

...
auto compile_settings = trtorch::CompileSpec(dims);
// FP16 execution
compile_settings.op_precision = torch::kHalf;
// Set input datatypes. Allowed options torch::{kFloat, kHalf, kChar, kInt32, kBool}
// Size of input_dtypes should match number of inputs to the network.
// If input_dtypes is not set, default precision for input tensors would be float32
compile_spec.input_dtypes = {torch::kHalf};
// If input_dtypes is not set, default precision follows traditional PyT / TRT rules
auto input = trtorch::CompileSpec::Input(dims, torch::kHalf);
auto compile_settings = trtorch::CompileSpec({input});
// FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
// Compile module
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
@@ -40,15 +40,14 @@ import trtorch
...
compile_settings = {
"input_shapes": [
{
"min": [1, 3, 224, 224],
"opt": [1, 3, 512, 512],
"max": [1, 3, 1024, 1024]
}, # For static size [1, 3, 224, 224]
],
"op_precision": torch.half, # Run with FP16
"input_dtypes": [torch.half] # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
"inputs": [trtorch.Input(
min_shape=[1, 3, 224, 224],
opt_shape=[1, 3, 512, 512],
max_shape=[1, 3, 1024, 1024],
# For static size shape=[1, 3, 224, 224]
dtype=torch.half, # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
)],
"enabled_precision": {torch.half}, # Run with FP16
}
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
@@ -59,9 +58,9 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```

> Notes on running in lower precisions:
> - Set precision with compile_spec.op_precision
> - Enable lower precisions with compile_spec.enabled_precisions
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
> - In FP16, input tensors should be FP16 by default; other precisions use FP32. This can be overridden by setting Input::dtype (see the sketch below)
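
As a concrete illustration of the last note, a minimal, hypothetical sketch (module and shape values are illustrative) of pinning the input dtype through `Input` so it no longer follows the FP16 default:

```py
import torch
import trtorch

# Keep FP32 inputs even though FP16 kernels are enabled; an explicit
# dtype on the Input spec overrides the default input-precision rule.
compile_settings = {
    "inputs": [trtorch.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
    "enabled_precisions": {torch.float, torch.half},
}
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
```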
## Platform Support

18 changes: 9 additions & 9 deletions docsrc/tutorials/getting_started.rst
@@ -204,14 +204,14 @@ to load in a deployment application. In order to load a TensorRT/TorchScript module
script_model.eval() # torch module needs to be in eval (not training) mode
compile_settings = {
"input_shapes": [
{
"min": [1, 1, 16, 16],
"opt": [1, 1, 32, 32],
"max": [1, 1, 64, 64]
},
"inputs": [trtorch.Input(
min_shape=[1, 1, 16, 16],
opt_shape=[1, 1, 32, 32],
max_shape=[1, 1, 64, 64],
dtype=torch.half,
),
],
"op_precision": torch.half # Run with fp16
"enable_precisions": {torch.float, torch.half} # Run with fp16
}
trt_ts_module = trtorch.compile(script_model, compile_settings)
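
For completeness, a hedged sketch of the save/load deployment flow this section describes (file name illustrative):

.. code-block:: python

    import torch
    import trtorch

    # Compile with the settings above and serialize to disk
    trt_ts_module = trtorch.compile(script_model, compile_settings)
    torch.jit.save(trt_ts_module, "trt_ts_module.ts")

    # In the deployment application: no model definition needed,
    # only the TRTorch runtime and the saved TorchScript file
    loaded = torch.jit.load("trt_ts_module.ts")
    x = torch.randn(1, 1, 16, 16).cuda().half()  # matches the Input spec above
    out = loaded(x)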
@@ -324,7 +324,7 @@ We can also set settings like operating precision to run in FP16.
auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
auto input_sizes = std::vector<trtorch::CompileSpec::InputRange>({in.sizes()});
trtorch::CompileSpec info(input_sizes);
info.op_precision = torch::kHALF;
info.enabled_precisions.insert(torch::kHALF);
auto trt_mod = trtorch::CompileGraph(mod, info);
auto out = trt_mod.forward({in});

@@ -372,7 +372,7 @@ If you want to save the engine produced by TRTorch to use in a TensorRT application
auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
auto input_sizes = std::vector<trtorch::CompileSpec::InputRange>({in.sizes()});
trtorch::CompileSpec info(input_sizes);
info.op_precision = torch::kHALF;
info.enabled_precisions.insert(torch::kHALF);
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", info);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
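
A hedged Python counterpart, assuming the ``trtorch.convert_method_to_trt_engine`` API and that it returns the serialized engine directly:

.. code-block:: python

    import trtorch

    # Convert the "forward" method into a standalone serialized TensorRT engine
    engine = trtorch.convert_method_to_trt_engine(script_model, "forward", compile_settings)
    with open("/tmp/engine_converted_from_jit.trt", "wb") as f:
        f.write(engine)  # assumes the engine is returned as bytes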
11 changes: 6 additions & 5 deletions docsrc/tutorials/ptq.rst
@@ -126,7 +126,8 @@ Then all that's required to set up the module for INT8 calibration is to set the following compile settings
/// Configure settings for compilation
auto compile_spec = trtorch::CompileSpec({input_shape});
/// Set operating precision to INT8
compile_spec.op_precision = torch::kI8;
compile_spec.enabled_precisions.insert(torch::kF16);
compile_spec.enabled_precisions.insert(torch::kI8);
/// Use the TensorRT Entropy Calibrator
compile_spec.ptq_calibrator = calibrator;
/// Set a larger workspace (you may get better performance from doing so)
@@ -169,8 +170,8 @@ a TensorRT calibrator by providing desired configuration. The following code demonstrates
device=torch.device('cuda:0'))
compile_spec = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"inputs": [trtorch.Input((1, 3, 32, 32))],
"enabled_precisions": {torch.float, torch.half, torch.int8},
"calibrator": calibrator,
"device": {
"device_type": trtorch.DeviceType.GPU,
@@ -190,8 +191,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
compile_settings = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"inputs": [trtorch.Input([1, 3, 32, 32])],
"enabled_precisions": {torch.float, torch.half, torch.int8},
"calibrator": calibrator,
"max_batch_size": 32,
}
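
To close the loop, a minimal sketch (model variable illustrative) of compiling and running with the cache-based settings above:

.. code-block:: python

    # Compile using the cached calibration data; no DataLoader needed
    trt_mod = trtorch.compile(model, compile_settings)

    # Inference inputs stay FP32; INT8 is an internal kernel precision
    out = trt_mod(torch.randn(1, 3, 32, 32).cuda())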
136 changes: 73 additions & 63 deletions docsrc/tutorials/trtorchc.rst
@@ -17,75 +17,85 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would run any other module
.. code-block:: txt
trtorchc [input_file_path] [output_file_path]
[input_shapes...] {OPTIONS}
[input_specs...] {OPTIONS}
TRTorch is a compiler for TorchScript, it will compile and optimize
TorchScript programs to run on NVIDIA GPUs using TensorRT
TRTorch is a compiler for TorchScript, it will compile and optimize
TorchScript programs to run on NVIDIA GPUs using TensorRT
OPTIONS:
OPTIONS:
-h, --help Display this help menu
Verbosity of the compiler
-v, --verbose Dumps debugging information about the
compilation process onto the console
-w, --warnings Disables warnings generated during
compilation onto the console (warnings
are on by default)
--info Dumps info messages generated during
compilation onto the console
--build-debuggable-engine Creates a debuggable engine
--use-strict-types Restrict operating type to only use set
default operation precision
(op_precision)
--allow-gpu-fallback (Only used when targeting DLA
(device-type)) Lets engine run layers on
GPU if they are not supported on DLA
-p[precision],
--default-op-precision=[precision]
Default operating precision for the
engine (Int8 requires a
calibration-cache argument) [ float |
float32 | f32 | half | float16 | f16 |
int8 | i8 ] (default: float)
-d[type], --device-type=[type] The type of device the engine should be
built for [ gpu | dla ] (default: gpu)
--engine-capability=[capability] The type of device the engine should be
built for [ default | safe_gpu |
safe_dla ]
--calibration-cache-file=[file_path]
Path to calibration cache file to use
for post training quantization
--num-min-timing-iter=[num_iters] Number of minimization timing iterations
used to select kernels
--num-avg-timing-iters=[num_iters]
Number of averaging timing iterations
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
--max-batch-size=[max_batch_size] Maximum batch size (must be >= 1 to be
set, 0 means not set)
-t[threshold],
--threshold=[threshold] Maximum acceptable numerical deviation
from standard torchscript output
(default 2e-5)
--save-engine Instead of compiling a full
TorchScript program, save the created
engine to the path specified as the
output path
input_file_path Path to input TorchScript file
output_file_path Path for compiled TorchScript (or
TensorRT engine) file
input_shapes... Sizes for inputs to engine, can either
be a single size or a range defined by
Min, Optimal, Max sizes, e.g.
"(N,..,C,H,W)"
"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
"--" can be used to terminate flag options and force all following
arguments to be treated as positional options
-h, --help Display this help menu
Verbosity of the compiler
-v, --verbose Dumps debugging information about the
compilation process onto the console
-w, --warnings Disables warnings generated during
compilation onto the console (warnings
are on by default)
--i, --info Dumps info messages generated during
compilation onto the console
--build-debuggable-engine Creates a debuggable engine
--use-strict-types Restrict operating type to only use set
operation precision
--allow-gpu-fallback (Only used when targeting DLA
(device-type)) Lets engine run layers on
GPU if they are not supported on DLA
--disable-tf32 Prevent Float32 layers from using the
TF32 data format
-p[precision...],
--enabled-precision=[precision...] (Repeatable) Enables an operating
precision for kernels to use when
building the engine (Int8 requires a
calibration-cache argument) [ float |
float32 | f32 | half | float16 | f16 |
int8 | i8 ] (default: float)
-d[type], --device-type=[type] The type of device the engine should be
built for [ gpu | dla ] (default: gpu)
--gpu-id=[gpu_id] GPU id if running on multi-GPU platform
(defaults to 0)
--dla-core=[dla_core] DLACore id if running on available DLA
(defaults to 0)
--engine-capability=[capability] The type of device the engine should be
built for [ default | safe_gpu |
safe_dla ]
--calibration-cache-file=[file_path]
Path to calibration cache file to use
for post training quantization
--num-min-timing-iter=[num_iters] Number of minimization timing iterations
used to select kernels
--num-avg-timing-iters=[num_iters]
Number of averaging timing iterations
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
--max-batch-size=[max_batch_size] Maximum batch size (must be >= 1 to be
set, 0 means not set)
-t[threshold],
--threshold=[threshold] Maximum acceptable numerical deviation
from standard torchscript output
(default 2e-5)
--save-engine Instead of compiling a full
TorchScript program, save the created
engine to the path specified as the
output path
input_file_path Path to input TorchScript file
output_file_path Path for compiled TorchScript (or
TensorRT engine) file
input_specs... Specs for inputs to engine, can either
be a single size or a range defined by
Min, Optimal, Max sizes, e.g.
"(N,..,C,H,W)"
"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]".
Data Type and format can be specified by
adding an "@" followed by dtype and "%"
followed by format to the end of the
shape spec. e.g. "(3, 3, 32,
32)@f16%NHWC"
"--" can be used to terminate flag options and force all following
arguments to be treated as positional options
e.g.

.. code-block:: shell
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16
4 changes: 2 additions & 2 deletions docsrc/tutorials/use_from_pytorch.rst
@@ -34,8 +34,8 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
spec = {
"forward":
trtorch.TensorRTCompileSpec({
"input_shapes": [[1, 3, 300, 300]],
"op_precision": torch.half,
"inputs": [trtorch.Input([1, 3, 300, 300])],
"enabled_precisions": {torch.float, torch.half},
"refit": False,
"debug": False,
"strict_types": False,
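For context, a hedged sketch of how this spec is consumed by the ``to_backend`` path, assuming the ``torch._C._jit_to_backend`` call used in this tutorial:

.. code-block:: python

    import torch

    # script_model is the TorchScript module prepared earlier in the tutorial
    trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
    out = trt_model.forward(torch.randn(1, 3, 300, 300).cuda())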
8 changes: 4 additions & 4 deletions docsrc/tutorials/using_dla.rst
@@ -15,13 +15,13 @@ Using DLA with trtorchc
Using DLA in a C++ application

.. code-block:: shell
.. code-block:: c++

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
auto compile_spec = trtorch::CompileSpec({input_shape});

// Set a precision. DLA supports fp16 or int8 only
compile_spec.op_precision = torch::kF16;
compile_spec.enabled_precisions = {torch::kF16};
compile_spec.device.device_type = trtorch::CompileSpec::DeviceType::kDLA;

// Make sure the gpu id is set to Xavier id for DLA
@@ -42,14 +42,14 @@ Using DLA in a python application
.. code-block:: python
compile_spec = {
"input_shapes": [self.input.shape],
"inputs": [trtorch.Input(self.input.shape)],
"device": {
"device_type": trtorch.DeviceType.DLA,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True
},
"op_precision": torch.half
"enalbed_precisions": {torch.half}
}
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
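
And a short, hedged sketch of exercising the DLA-compiled module (input handling illustrative; the input is FP16 to match ``enabled_precisions``):

.. code-block:: python

    x = self.input.cuda().half()  # FP16 input to match the enabled precision
    out = trt_mod(x)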