Skip to content

Commit

Permalink
Fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Sep 23, 2024
1 parent 7bf6d9f commit c42afe1
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 68 deletions.
13 changes: 8 additions & 5 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,7 @@ Stub::ProcessReturnedResponses(
std::string(py::str(py_responses[i].get_type())) + "'.");
}

std::shared_ptr<InferResponse> response =
py_responses[i].cast<std::shared_ptr<InferResponse>>();
InferResponse* response = py_responses[i].cast<InferResponse*>();
request->GetResponseSender()->UpdateStateAndCounters(
response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
}
Expand All @@ -845,9 +844,13 @@ Stub::ProcessReturnedResponses(
// Check the return type of execute function.
InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
if (!py::isinstance<py::none>(py_responses[i])) {
infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
} else {
responses_shm_handle[i] = 0;
}
}
response_batch_shm_ptr->batch_size = requests_size;
}
Expand Down
16 changes: 9 additions & 7 deletions src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1026,13 +1026,7 @@ ModelInstanceState::SendMessageAndReceiveResponse(
SendMessageToStub(message);

bi::managed_external_buffer::handle_t response_message;
auto error = Stub()->ReceiveMessageFromStub(response_message);
if (error != nullptr) {
RespondErrorToAllRequests(
TRITONSERVER_ErrorMessage(error), responses, requests, request_count);

return;
}
Stub()->ReceiveMessageFromStub(response_message);

response = response_message;
}
Expand Down Expand Up @@ -1355,6 +1349,14 @@ ModelInstanceState::ProcessRequests(
(*responses)[r] = nullptr;
continue;
}

if (response_shm_handle[r] == 0) {
LOG_IF_ERROR(
TRITONBACKEND_ResponseDelete((*responses)[r]),
"failed to delete response");
(*responses)[r] = nullptr;
continue;
}
infer_response = InferResponse::LoadFromSharedMemory(
Stub()->ShmPool(), response_shm_handle[r],
false /* open_cuda_handle */);
Expand Down
5 changes: 3 additions & 2 deletions src/response_sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ ResponseSender::~ResponseSender()

void
ResponseSender::UpdateStateAndCounters(
const std::shared_ptr<InferResponse>& response, const uint32_t flags)
InferResponse* response, const uint32_t flags)
{
if (is_decoupled_ == nullptr) {
// TODO: Can a model access the response sender on a BLS infer request?
Expand Down Expand Up @@ -106,6 +106,7 @@ ResponseSender::UpdateStateAndCounters(
}

if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
response_factory_deleted_.exchange(true);
closed_ = true;
}
number_of_response_sent_++;
Expand All @@ -123,7 +124,7 @@ ResponseSender::Send(
py::gil_scoped_release release;

CheckResponseSenderArguments(infer_response, flags);
UpdateStateAndCounters(infer_response, flags);
UpdateStateAndCounters(infer_response.get(), flags);
if (infer_response) {
infer_response->PruneOutputTensors(requested_output_names_);
}
Expand Down
3 changes: 1 addition & 2 deletions src/response_sender.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ class ResponseSender {
~ResponseSender();
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
bool IsCancelled();
void UpdateStateAndCounters(
const std::shared_ptr<InferResponse>& response, const uint32_t flags);
void UpdateStateAndCounters(InferResponse* response, const uint32_t flags);

// Can be useful at stopping the model from sending any more responses.
void Close();
Expand Down
53 changes: 3 additions & 50 deletions src/stub_launcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ StubLauncher::ModelInstanceStubProcess()
stub_message_queue_->Push(initialize_message->ShmHandle());

bi::managed_external_buffer::handle_t message;
RETURN_IF_ERROR(ReceiveMessageFromStub(message));
ReceiveMessageFromStub(message);

std::unique_ptr<IPCMessage> initialize_response_message =
IPCMessage::LoadFromSharedMemory(shm_pool_, message);
Expand Down Expand Up @@ -724,58 +724,11 @@ StubLauncher::KillStubProcess()
#endif
}

TRITONSERVER_Error*
void
StubLauncher::ReceiveMessageFromStub(
bi::managed_external_buffer::handle_t& message)
{
bool success = false;
while (!success) {
uint64_t timeout_miliseconds = 1000;
{
boost::posix_time::ptime timeout =
boost::get_system_time() +
boost::posix_time::milliseconds(timeout_miliseconds);

bi::scoped_lock<bi::interprocess_mutex> lock(*health_mutex_, timeout);

// Check if lock has been acquired.
if (lock) {
ipc_control_->stub_health = false;
} else {
// If it failed to obtain the lock, it means that the stub has been
// stuck or exited while holding the health mutex lock.
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Failed to obtain the health mutex.");
}
}

message = parent_message_queue_->Pop(
timeout_miliseconds /* duration ms */, success);

bool is_stub_alive = false;
{
boost::posix_time::ptime timeout =
boost::get_system_time() + boost::posix_time::seconds(1);
bi::scoped_lock<bi::interprocess_mutex> lock(*health_mutex_, timeout);
if (lock) {
is_stub_alive = ipc_control_->stub_health;
} else {
// If it failed to obtain the lock, it means that the stub has been
// stuck or exited while holding the health mutex lock.
is_stub_alive = false;
}
}

if (!success && !is_stub_alive) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
(std::string("Stub process '") + model_instance_name_ +
"' is not healthy.")
.c_str());
}
}

return nullptr; // success
message = parent_message_queue_->Pop();
}

void
Expand Down
3 changes: 1 addition & 2 deletions src/stub_launcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ class StubLauncher {
void KillStubProcess();

// Get a message from the stub process
TRITONSERVER_Error* ReceiveMessageFromStub(
bi::managed_external_buffer::handle_t& message);
void ReceiveMessageFromStub(bi::managed_external_buffer::handle_t& message);

// Wait for stub process
void WaitForStubProcess();
Expand Down

0 comments on commit c42afe1

Please sign in to comment.