Skip to content

Commit

Permalink
refactor: Apply linting
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 21, 2021
1 parent aaddaf1 commit 265f71e
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 116 deletions.
8 changes: 3 additions & 5 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
<< "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}

void AddInputs(
ConversionCtx* ctx,
at::ArrayRef<const torch::jit::Value*> inputs,
std::vector<ir::Input>& input_specs) {
void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
Expand Down Expand Up @@ -164,7 +161,8 @@ void AddInputs(
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
ctx->logger,
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec
<< " in engine (conversion.AddInputs)");

auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
Expand Down
3 changes: 1 addition & 2 deletions core/conversion/conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ namespace conversion {
struct ConversionInfo {
std::vector<ir::Input> inputs;
BuilderSettings engine_settings;
ConversionInfo(std::vector<ir::Input> inputs)
: inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};

// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
Expand Down
6 changes: 4 additions & 2 deletions core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
LOG_DEBUG(build_settings);
cfg = builder->createBuilderConfig();

for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
switch (*p) {
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
Expand Down Expand Up @@ -121,7 +121,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
static_cast<int>(settings.device.dla_core) < nbDLACores,
"Configured DLA Core ID: " << settings.device.dla_core
<< " not available. Total number of available DLA Cores: " << nbDLACores);
TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
TRTORCH_CHECK(
settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
"DLA supports only fp16 or int8 precision");
cfg->setDLACore(settings.device.dla_core);
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#include <map>
#include <memory>
#include <unordered_map>
#include <set>
#include <unordered_map>

#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
Expand Down
5 changes: 3 additions & 2 deletions core/conversion/converters/impl/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ auto acthardtanh TRTORCH_UNUSED =
std::string pluginName = "CustomGeluPluginDynamic";
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
//REVIEW is this right?
int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
// REVIEW is this right?
int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
ctx->settings.enabled_precisions.end()
? 0
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
Expand Down
22 changes: 18 additions & 4 deletions core/ir/Input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,20 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten

TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
TRTORCH_CHECK(
valid_dtype_format_combo(dtype, format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}

Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
Input::Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype,
nvinfer1::TensorFormat format) {
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
LOG_WARNING("Verify that this dim size is accepted");
}
Expand Down Expand Up @@ -174,15 +183,20 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std

TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
TRTORCH_CHECK(
valid_dtype_format_combo(dtype, format),
"Unsupported combination of dtype and tensor format: ("
<< dtype << ", " << format
<< "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}

std::ostream& operator<<(std::ostream& os, const Input& input) {
if (!input.input_is_dynamic) {
os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
} else {
os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt
<< ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
}
return os;
}
Expand Down
18 changes: 13 additions & 5 deletions core/ir/ir.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
#pragma once

#include <vector>
#include <iostream>
#include <vector>
#include "NvInfer.h"

namespace trtorch {
namespace core {
namespace ir {

struct Input {
//Input(std::vector<int64_t> shape);
//Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
Input(std::vector<int64_t> shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
// Input(std::vector<int64_t> shape);
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
Input(
std::vector<int64_t> shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
friend std::ostream& operator<<(std::ostream& os, const Input& input);

bool input_is_dynamic = false;
Expand Down
Loading

0 comments on commit 265f71e

Please sign in to comment.