diff --git a/CMakeLists.txt b/CMakeLists.txt index d626bdf0..5da6ecd5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,6 +108,7 @@ add_library( src/tensorrt_model_instance.h src/tensorrt_utils.cc src/tensorrt_utils.h + src/semaphore.h src/loader.cc src/loader.h src/logging.cc diff --git a/src/semaphore.h b/src/semaphore.h new file mode 100644 index 00000000..95de3d55 --- /dev/null +++ b/src/semaphore.h @@ -0,0 +1,57 @@ +// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <condition_variable> +#include <mutex> + +namespace triton { namespace backend { namespace tensorrt { + +class Semaphore { + public: + explicit Semaphore(const int count) : count_(count) {} + + void Release() + { + std::unique_lock<std::mutex> lck(mtx_); + count_++; + cv_.notify_one(); + } + + void Acquire() + { + std::unique_lock<std::mutex> lck(mtx_); + cv_.wait(lck, [this]() { return (count_ > 0); }); + count_--; + } + + private: + int count_; + + std::mutex mtx_; + std::condition_variable cv_; +}; + +}}} // namespace triton::backend::tensorrt diff --git a/src/tensorrt.cc b/src/tensorrt.cc index 2a9f2f0e..09e52190 100644 --- a/src/tensorrt.cc +++ b/src/tensorrt.cc @@ -27,6 +27,7 @@ #include <future> #include "loader.h" #include "logging.h" +#include "semaphore.h" #include "tensorrt_model.h" #include "tensorrt_model_instance.h" #include "tensorrt_utils.h" @@ -221,6 +222,23 @@ class ModelState : public TensorRTModel { void DisableEngineSharing() { engine_sharing_ = false; } bool IsEngineSharingEnabled() { return engine_sharing_; } + struct SemaphoreContext { + SemaphoreContext() : next_sem_idx_(0) {} + + std::vector<std::unique_ptr<Semaphore>> semaphore_list_; + int next_sem_idx_; + }; + + std::map<int, std::unique_ptr<SemaphoreContext>>& SemaphoreMap() + { + return semaphore_map_; + } + + std::unique_ptr<SemaphoreContext>& SemaphoreDeviceContext(const int device_id) + { + return semaphore_map_[device_id]; + } + private: ModelState(TRITONBACKEND_Model* triton_model); @@ -264,6 +282,9 @@ class ModelState : public TensorRTModel { std::shared_ptr<nvinfer1::ICudaEngine>>> device_engines_;
bool engine_sharing_; + + // A map between device id to its semaphore context + std::map<int, std::unique_ptr<SemaphoreContext>> semaphore_map_; }; TRITONSERVER_Error* @@ -976,7 +997,7 @@ class ModelInstanceState : public TensorRTModelInstance { ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance); - void RegisterContexts(); + void RegisterSemaphore(); TRITONSERVER_Error* InitStreamsAndEvents(); TRITONSERVER_Error* InitEventSet(bool busy_wait_events); TRITONSERVER_Error* DestroyEventSet(); @@ -1203,19 +1224,16 @@ class ModelInstanceState : public TensorRTModelInstance { // executions' event states. std::thread completion_thread_; - triton::common::SyncQueue<size_t> context_queue_; - size_t next_context_idx_; - // The details needed by the completion thread to finalize the // response for a model execution. struct Payload { explicit Payload( size_t event_set_idx, TRITONBACKEND_Request** requests, - uint32_t request_count, size_t context_idx) + uint32_t request_count, size_t sem_idx) : event_set_idx_(event_set_idx), total_batch_size_(0), compute_start_ns_(0), compute_input_end_ns_(0), compute_output_start_ns_(0), requests_(requests), - request_count_(request_count), context_idx_(context_idx) + request_count_(request_count), sem_idx_(sem_idx) { } @@ -1234,7 +1252,7 @@ class ModelInstanceState : public TensorRTModelInstance { std::vector<TRITONBACKEND_Request*> requests_list_; TRITONBACKEND_Request** requests_; uint32_t request_count_; - size_t context_idx_; + size_t sem_idx_; // All the generated InferenceResponse objects std::vector<TRITONBACKEND_Response*> responses_; @@ -1360,7 +1378,7 @@ ModelInstanceState::Create( "' for model instance '" + (*state)->Name() + "'"); } - (*state)->RegisterContexts(); + (*state)->RegisterSemaphore(); RETURN_IF_ERROR((*state)->InitStreamsAndEvents()); RETURN_IF_ERROR(model_state->CreateEngine( (*state)->DeviceId(), (*state)->DLACoreId(), model_path, @@ -1579,9 +1597,11 @@ ModelInstanceState::ProcessRequests( std::to_string(request_count) + " requests") .c_str()); - auto context_idx = next_context_idx_; +
auto& sem_context = (model_state_->SemaphoreDeviceContext(DeviceId())); + + auto sem_idx = sem_context->next_sem_idx_; - Run(requests, request_count, context_idx); + Run(requests, request_count, sem_idx); bool run_failed = true; for (size_t i = 0; i < request_count; ++i) { @@ -1597,7 +1617,7 @@ ModelInstanceState::ProcessRequests( if (run_failed) { // On inference error, place the slot back to the queue // immediately as all works for the slot should be ignored. - context_queue_.Put(context_idx); + sem_context->semaphore_list_[sem_idx]->Release(); } else { auto event_set_idx = next_set_; next_set_ = (event_set_idx + 1) % EVENT_SET_COUNT; @@ -1620,7 +1640,9 @@ ModelInstanceState::ProcessRequests( } // Block the execution if there are no available contexts. - next_context_idx_ = context_queue_.Get(); + sem_context->next_sem_idx_ = + (sem_idx + 1) % sem_context->semaphore_list_.size(); + sem_context->semaphore_list_[sem_idx]->Acquire(); } void @@ -2528,7 +2550,9 @@ ModelInstanceState::ProcessResponse() // slots so that it can begin enqueuing new memcpys into the input // buffers cudaEventSynchronize(event_set.ready_for_input_); - context_queue_.Put(payload->context_idx_); + (model_state_->SemaphoreDeviceContext(DeviceId())) + ->semaphore_list_[payload->sem_idx_] + ->Release(); NVTX_MARKER("plan_input_available"); // Call Finalize() here to defer CUDA synchronization as much as @@ -2963,22 +2987,28 @@ ModelInstanceState::DestroyEventSet() } void -ModelInstanceState::RegisterContexts() +ModelInstanceState::RegisterSemaphore() { - size_t context_idx = 0; - context_queue_.Put(context_idx++); - // If eager batching is set, we add additional slots per device + // If eager batching is set, we add to the semaphore resource count // which allows to start preparing next batch before the previous - // batch has completed. The number of duplicates are limitedby + // batch has completed. 
The number of duplicates are limited by // number of event sets to prevent too many iterations are run // ahead and to avoid interference of the event communication in // the previous execution - if (model_state_->EagerBatching()) { - for (int count = 1; count < EVENT_SET_COUNT; ++count) { - context_queue_.Put(context_idx++); - } + int sem_count = (model_state_->EagerBatching()) ? EVENT_SET_COUNT : 1; + auto it = (model_state_->SemaphoreMap()).find(DeviceId()); + if (it == (model_state_->SemaphoreMap()).end()) { + it = (model_state_->SemaphoreMap()) + .emplace( + std::make_pair(DeviceId(), new ModelState::SemaphoreContext())) + .first; + } + it->second->semaphore_list_.emplace_back(new Semaphore(sem_count)); + + if (it->second->semaphore_list_.size() == 1) { + // Need to acquire a semaphore for first inference request + it->second->semaphore_list_[it->second->next_sem_idx_]->Acquire(); } - next_context_idx_ = context_queue_.Get(); } TRITONSERVER_Error*