diff --git a/src/tensorrt.cc b/src/tensorrt.cc index df046e96..2a9f2f0e 100644 --- a/src/tensorrt.cc +++ b/src/tensorrt.cc @@ -3819,18 +3819,10 @@ ModelInstanceState::InitializeExecuteInputBinding( int io_index = engine_->getBindingIndex(input_name.c_str()); auto& io_binding_info = io_binding_infos_[next_buffer_binding_set_][io_index]; - if (io_binding_info.buffer_ != nullptr) { - if (!is_state) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("input '") + input_name + - "' has already appeared as an input or output for " + Name()) - .c_str()); - } else { - // The input bindings for the given input is already allocated, - // hence, no need to proceed further. - return nullptr; - } + if ((io_binding_info.buffer_ != nullptr) && is_state) { + // The input bindings for the given state input is already allocated, + // hence, no need to proceed further. + return nullptr; } for (auto& trt_context : trt_contexts_) { @@ -3849,6 +3841,14 @@ ModelInstanceState::InitializeExecuteInputBinding( return nullptr; } + if (io_binding_info.buffer_ != nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("input '") + input_name + + "' has already appeared as an input or output for " + Name()) + .c_str()); + } + if (!engine_->bindingIsInput(binding_index)) { return TRITONSERVER_ErrorNew( @@ -4096,16 +4096,10 @@ ModelInstanceState::InitializeExecuteOutputBinding( io_binding_info.is_requested_output_tensor_ = true; } - if (io_binding_info.buffer_ != nullptr) { - if (!is_state) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("output '") + output_name + - "' has already appeared as an input or output for " + Name()) - .c_str()); - } else { - return nullptr; - } + if ((io_binding_info.buffer_ != nullptr) && is_state) { + // The input bindings for the given state input is already allocated, + // hence, no need to proceed further. + return nullptr; } for (auto& trt_context : trt_contexts_) { @@ -4127,6 +4121,14 @@ ModelInstanceState::InitializeExecuteOutputBinding( .c_str()); } + if (io_binding_info.buffer_ != nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("output '") + output_name + + "' has already appeared as an input or output for " + Name()) + .c_str()); + } + TRITONSERVER_DataType dt = ConvertTrtTypeToDataType(engine_->getBindingDataType(binding_index)); TRITONSERVER_DataType config_dt =