Skip to content

Commit

Permalink
fix: Fix TRT8 engine capability flags
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Aug 9, 2021
1 parent e336630 commit 2b69742
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cpp/trtorchc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ trtorchc [input_file_path] [output_file_path]
--dla-core=[dla_core] DLACore id if running on available DLA
(defaults to 0)
--engine-capability=[capability] The type of device the engine should be
built for [ default | safe_gpu |
safe_dla ]
built for [ standard | safety |
dla_standalone ]
--calibration-cache-file=[file_path]
Path to calibration cache file to use
for post training quantization
Expand Down
14 changes: 7 additions & 7 deletions cpp/trtorchc/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ int main(int argc, char** argv) {
args::ValueFlag<std::string> engine_capability(
parser,
"capability",
"The type of device the engine should be built for [ default | safe_gpu | safe_dla ]",
"The type of device the engine should be built for [ standard | safety | dla_standalone ]",
{"engine-capability"});

args::ValueFlag<std::string> calibration_cache_file(
Expand Down Expand Up @@ -537,12 +537,12 @@ int main(int argc, char** argv) {
auto capability = args::get(engine_capability);
std::transform(
capability.begin(), capability.end(), capability.begin(), [](unsigned char c) { return std::tolower(c); });
if (capability == "default") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDEFAULT;
} else if (capability == "safe_gpu") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_GPU;
} else if (capability == "safe_dla") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_DLA;
if (capability == "standard") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSTANDARD;
} else if (capability == "safety") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFETY;
} else if (capability == "dla_standalone") {
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDLA_STANDALONE;
} else {
trtorch::logging::log(
trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]");
Expand Down
8 changes: 4 additions & 4 deletions docsrc/tutorials/trtorchc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
--ffo,
--forced-fallback-ops List of operators in the graph that
should be forced to fallback to Pytorch for execution
--disable-tf32 Prevent Float32 layers from using the
TF32 data format
-p[precision...],
Expand All @@ -55,16 +55,16 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
calibration-cache argument) [ float |
float32 | f32 | half | float16 | f16 |
int8 | i8 ] (default: float)
-d[type], --device-type=[type] The type of device the engine should be
built for [ gpu | dla ] (default: gpu)
--gpu-id=[gpu_id] GPU id if running on multi-GPU platform
(defaults to 0)
--dla-core=[dla_core] DLACore id if running on available DLA
(defaults to 0)
--engine-capability=[capability] The type of device the engine should be
built for [ default | safe_gpu |
safe_dla ]
built for [ standard | safety |
dla_standalone ]
--calibration-cache-file=[file_path]
Path to calibration cache file to use
for post training quantization
Expand Down

0 comments on commit 2b69742

Please sign in to comment.