From 2b69742d1ad73e01fbc9199a02b9b0cfddc3f4ed Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 9 Aug 2021 11:47:16 -0700 Subject: [PATCH] fix: Fix TRT8 engine capability flags Signed-off-by: Dheeraj Peri --- cpp/trtorchc/README.md | 4 ++-- cpp/trtorchc/main.cpp | 14 +++++++------- docsrc/tutorials/trtorchc.rst | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/trtorchc/README.md b/cpp/trtorchc/README.md index 863428132f..c4e254f129 100644 --- a/cpp/trtorchc/README.md +++ b/cpp/trtorchc/README.md @@ -59,8 +59,8 @@ trtorchc [input_file_path] [output_file_path] --dla-core=[dla_core] DLACore id if running on available DLA (defaults to 0) --engine-capability=[capability] The type of device the engine should be - built for [ default | safe_gpu | - safe_dla ] + built for [ standard | safety | + dla_standalone ] --calibration-cache-file=[file_path] Path to calibration cache file to use for post training quantization diff --git a/cpp/trtorchc/main.cpp b/cpp/trtorchc/main.cpp index 35d7b7d2cc..547b6dbc83 100644 --- a/cpp/trtorchc/main.cpp +++ b/cpp/trtorchc/main.cpp @@ -264,7 +264,7 @@ int main(int argc, char** argv) { args::ValueFlag engine_capability( parser, "capability", - "The type of device the engine should be built for [ default | safe_gpu | safe_dla ]", + "The type of device the engine should be built for [ standard | safety | dla_standalone ]", {"engine-capability"}); args::ValueFlag calibration_cache_file( @@ -537,12 +537,12 @@ int main(int argc, char** argv) { auto capability = args::get(engine_capability); std::transform( capability.begin(), capability.end(), capability.begin(), [](unsigned char c) { return std::tolower(c); }); - if (capability == "default") { - compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDEFAULT; - } else if (capability == "safe_gpu") { - compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_GPU; - } else if (capability == "safe_dla") { - compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_DLA; + if (capability == "standard") { + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSTANDARD; + } else if (capability == "safety") { + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFETY; + } else if (capability == "dla_standalone") { + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDLA_STANDALONE; } else { trtorch::logging::log( trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]"); diff --git a/docsrc/tutorials/trtorchc.rst b/docsrc/tutorials/trtorchc.rst index 4d7c232f95..9bf9809569 100644 --- a/docsrc/tutorials/trtorchc.rst +++ b/docsrc/tutorials/trtorchc.rst @@ -45,7 +45,7 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r --ffo, --forced-fallback-ops List of operators in the graph that should be forced to fallback to Pytorch for execution - + --disable-tf32 Prevent Float32 layers from using the TF32 data format -p[precision...], @@ -55,7 +55,7 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float) - + -d[type], --device-type=[type] The type of device the engine should be built for [ gpu | dla ] (default: gpu) --gpu-id=[gpu_id] GPU id if running on multi-GPU platform @@ -63,8 +63,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r --dla-core=[dla_core] DLACore id if running on available DLA (defaults to 0) --engine-capability=[capability] The type of device the engine should be - built for [ default | safe_gpu | - safe_dla ] + built for [ standard | safety | + dla_standalone ] --calibration-cache-file=[file_path] Path to calibration cache file to use for post training quantization