From 8dc1a0648ee7c07667d96e2a2871ef0219e3a8f1 Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 30 Nov 2022 17:57:47 -0500
Subject: [PATCH] fix: Repair Citrinet-1024 compilation issues [Duplicate of
 PR #1488 for Release 1.3] (#1489)

---
 core/conversion/converters/impl/element_wise.cpp |  3 ++-
 core/conversion/converters/impl/reduce.cpp       |  8 ++++++++
 tests/core/conversion/converters/test_reduce.cpp | 14 ++++++++++++++
 3 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 4e1fab4929..a86307c682 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -325,7 +325,8 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                 add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
           } else if (rounding_mode == "trunc") {
             // trunc = floor(abs(div)) * sign(div)
-            auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
+            auto tmp_div = add_elementwise(
+                ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_tmp_div");
             auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
 
             // In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 03e6bd20ab..b3db09ffd7 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -113,6 +113,14 @@ auto reduce_registrations TORCHTRT_UNUSED =
               LOG_DEBUG("Keep dims: " << keepdim);
 
               LOG_WARNING("Sum converter disregards dtype");
+
+              if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+                LOG_DEBUG(
+                    "Found type " << in_tensor->getType() << " in aten::sum, casting to "
+                                  << nvinfer1::DataType::kINT32 << " for compatibility.");
+                in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
+              }
+
               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
 
               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index e3e1e6d252..4699427d5e 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -5,6 +5,7 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/torch.h"
 
 namespace {
 std::string gen_basic_graph(const std::string& op) {
@@ -162,6 +163,19 @@ TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenSumDimNegOneIndexKeepDimsBoolTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %2 : int[] = prim::ListConstruct(%1)
+        %3 : bool = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::sum(%0, %2, %3, %4)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
 TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor):