Add support for correlation string in BLS calls in python backend #339

Closed · wants to merge 10 commits
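
The feature targets BLS code inside Python models: with this change, `correlation_id` may be passed as a string when building a BLS request. A minimal sketch of the intended usage, assuming the standard `triton_python_backend_utils` module and hypothetical model/tensor names:

```python
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            # "INPUT0", "OUTPUT0", and "downstream_model" are hypothetical names.
            input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")

            # With this PR, correlation_id accepts a str as well as an int;
            # the pybind11 constructor (src/pb_stub.cc below) dispatches on
            # the Python type and stores the value in the matching field.
            bls_request = pb_utils.InferenceRequest(
                model_name="downstream_model",
                requested_output_names=["OUTPUT0"],
                inputs=[input_tensor],
                correlation_id="session-1234",  # previously had to be an int
            )
            bls_response = bls_request.exec()

            output_tensor = pb_utils.get_output_tensor_by_name(
                bls_response, "OUTPUT0")
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[output_tensor]))
        return responses
```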
43 changes: 37 additions & 6 deletions src/infer_request.cc
@@ -39,13 +39,15 @@ namespace triton { namespace backend { namespace python {

InferRequest::InferRequest(
const std::string& request_id, uint64_t correlation_id,
+ const std::string& correlation_id_string,
const std::vector<std::shared_ptr<PbTensor>>& inputs,
const std::set<std::string>& requested_output_names,
const std::string& model_name, const int64_t model_version,
const std::string& parameters, const uint32_t flags, const uint64_t timeout,
const intptr_t response_factory_address, const intptr_t request_address,
const PreferredMemory& preferred_memory, const InferenceTrace& trace)
- : request_id_(request_id), correlation_id_(correlation_id), inputs_(inputs),
+ : request_id_(request_id), correlation_id_(correlation_id),
+ correlation_id_string_(correlation_id_string), inputs_(inputs),
requested_output_names_(requested_output_names), model_name_(model_name),
model_version_(model_version), parameters_(parameters), flags_(flags),
timeout_(timeout), response_factory_address_(response_factory_address),
@@ -103,6 +105,12 @@ InferRequest::CorrelationId()
return correlation_id_;
}

+ const std::string&
+ InferRequest::CorrelationIdString()
+ {
+ return correlation_id_string_;
+ }

const std::set<std::string>&
InferRequest::RequestedOutputNames()
{
@@ -199,6 +207,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
(Inputs().size() * sizeof(bi::managed_external_buffer::handle_t)) +
PbString::ShmStructSize(ModelName()) +
PbString::ShmStructSize(RequestId()) +
+ PbString::ShmStructSize(CorrelationIdString()) +
PbString::ShmStructSize(Parameters()));

infer_request_shm_ptr_ =
@@ -264,8 +273,16 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
reinterpret_cast<char*>(infer_request_shm_ptr_) + request_id_offset,
infer_request_shm.handle_ + request_id_offset);

- size_t parameters_offset =
+ size_t correlation_id_string_offset =
request_id_offset + PbString::ShmStructSize(RequestId());
+ std::unique_ptr<PbString> correlation_id_string_shm = PbString::Create(
+ CorrelationIdString(),
+ reinterpret_cast<char*>(infer_request_shm_ptr_) +
+ correlation_id_string_offset,
+ infer_request_shm.handle_ + correlation_id_string_offset);
+
+ size_t parameters_offset = correlation_id_string_offset +
+ PbString::ShmStructSize(CorrelationIdString());
std::unique_ptr<PbString> parameters_shm = PbString::Create(
Parameters(),
reinterpret_cast<char*>(infer_request_shm_ptr_) + parameters_offset,
@@ -274,6 +291,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
// Save the references to shared memory.
infer_request_shm_ = std::move(infer_request_shm);
request_id_shm_ = std::move(request_id_shm);
+ correlation_id_string_shm_ = std::move(correlation_id_string_shm);
model_name_shm_ = std::move(model_name_shm);
parameters_shm_ = std::move(parameters_shm);
shm_handle_ = infer_request_shm_.handle_;
@@ -336,25 +354,37 @@ InferRequest::LoadFromSharedMemory(
request_handle + request_id_offset,
reinterpret_cast<char*>(infer_request_shm_ptr) + request_id_offset);

- size_t parameters_offset = request_id_offset + request_id_shm->Size();
+ size_t correlation_id_string_offset =
+ request_id_offset + request_id_shm->Size();
+ std::unique_ptr<PbString> correlation_id_string_shm =
+ PbString::LoadFromSharedMemory(
+ request_handle + correlation_id_string_offset,
+ reinterpret_cast<char*>(infer_request_shm_ptr) +
+ correlation_id_string_offset);
+
+ size_t parameters_offset =
+ correlation_id_string_offset + correlation_id_string_shm->Size();
std::unique_ptr<PbString> parameters_shm = PbString::LoadFromSharedMemory(
- request_handle + request_id_offset,
+ request_handle + parameters_offset,
reinterpret_cast<char*>(infer_request_shm_ptr) + parameters_offset);

return std::unique_ptr<InferRequest>(new InferRequest(
- infer_request_shm, request_id_shm, requested_output_names_shm,
- model_name_shm, input_tensors, parameters_shm));
+ infer_request_shm, request_id_shm, correlation_id_string_shm,
+ requested_output_names_shm, model_name_shm, input_tensors,
+ parameters_shm));
}

InferRequest::InferRequest(
AllocatedSharedMemory<char>& infer_request_shm,
std::unique_ptr<PbString>& request_id_shm,
+ std::unique_ptr<PbString>& correlation_id_string_shm,
std::vector<std::unique_ptr<PbString>>& requested_output_names_shm,
std::unique_ptr<PbString>& model_name_shm,
std::vector<std::shared_ptr<PbTensor>>& input_tensors,
std::unique_ptr<PbString>& parameters_shm)
: infer_request_shm_(std::move(infer_request_shm)),
request_id_shm_(std::move(request_id_shm)),
+ correlation_id_string_shm_(std::move(correlation_id_string_shm)),
requested_output_names_shm_(std::move(requested_output_names_shm)),
model_name_shm_(std::move(model_name_shm)),
parameters_shm_(std::move(parameters_shm))
@@ -382,6 +412,7 @@ InferRequest::InferRequest(
}

request_id_ = request_id_shm_->String();
+ correlation_id_string_ = correlation_id_string_shm_->String();
parameters_ = parameters_shm_->String();
requested_output_names_ = std::move(requested_output_names);
model_name_ = model_name_shm_->String();
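
For orientation, the serialization above places the correlation string between the request ID and the parameters, so the save and load paths must walk the same cumulative offsets. A rough Python analogy of that offset walk, for illustration only: `shm_struct_size` is a stand-in for `PbString::ShmStructSize`, whose real layout lives on the C++ side.

```python
def shm_struct_size(s: str) -> int:
    # Stand-in: a length-prefixed string region. The actual size comes
    # from PbString::ShmStructSize in the backend; this is illustrative.
    return 4 + len(s.encode("utf-8"))


def string_region_offsets(model_name: str, request_id: str,
                          correlation_id_string: str, parameters: str,
                          base: int = 0) -> dict:
    """Cumulative offsets in the order SaveToSharedMemory writes them."""
    offsets, cursor = {}, base
    for name, value in [
        ("model_name", model_name),
        ("request_id", request_id),
        ("correlation_id_string", correlation_id_string),  # new in this PR
        ("parameters", parameters),
    ]:
        offsets[name] = cursor
        cursor += shm_struct_size(value)
    return offsets
```

This mirrors why LoadFromSharedMemory now derives correlation_id_string_offset from request_id_offset plus the request ID's size, and parameters_offset from the correlation string's size.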
5 changes: 5 additions & 0 deletions src/infer_request.h
@@ -80,6 +80,7 @@ class InferRequest {
public:
InferRequest(
const std::string& request_id, uint64_t correlation_id,
+ const std::string& correlation_id_string,
const std::vector<std::shared_ptr<PbTensor>>& inputs,
const std::set<std::string>& requested_output_names,
const std::string& model_name, const int64_t model_version,
@@ -94,6 +95,7 @@ class InferRequest {
const std::string& RequestId();
const std::string& Parameters();
uint64_t CorrelationId();
+ const std::string& CorrelationIdString();
const std::string& ModelName();
int64_t ModelVersion();
uint32_t Flags();
@@ -141,13 +143,15 @@ class InferRequest {
InferRequest(
AllocatedSharedMemory<char>& infer_request_shm,
std::unique_ptr<PbString>& request_id_shm,
+ std::unique_ptr<PbString>& correlation_id_string_shm,
std::vector<std::unique_ptr<PbString>>& requested_output_names_shm,
std::unique_ptr<PbString>& model_name_shm,
std::vector<std::shared_ptr<PbTensor>>& input_tensors,
std::unique_ptr<PbString>& parameters_shm);

std::string request_id_;
uint64_t correlation_id_;
+ std::string correlation_id_string_;
std::vector<std::shared_ptr<PbTensor>> inputs_;
std::set<std::string> requested_output_names_;
std::string model_name_;
@@ -167,6 +171,7 @@ class InferRequest {
InferRequestShm* infer_request_shm_ptr_;

std::unique_ptr<PbString> request_id_shm_;
+ std::unique_ptr<PbString> correlation_id_string_shm_;
std::vector<std::unique_ptr<PbString>> requested_output_names_shm_;
std::unique_ptr<PbString> model_name_shm_;
bi::managed_external_buffer::handle_t* output_names_handle_shm_ptr_;
25 changes: 20 additions & 5 deletions src/pb_stub.cc
@@ -1615,7 +1615,8 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
py::class_<InferRequest, std::shared_ptr<InferRequest>>(
module, "InferenceRequest")
.def(
- py::init([](const std::string& request_id, uint64_t correlation_id,
+ py::init([](const std::string& request_id,
+ const py::object& correlation_id,
const std::vector<std::shared_ptr<PbTensor>>& inputs,
const std::vector<std::string>& requested_output_names,
const std::string& model_name,
@@ -1648,11 +1649,24 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
py::module_ py_json = py::module_::import("json");
std::string parameters_str =
py::str(py_json.attr("dumps")(parameters));

+ uint64_t correlation_id_int = 0;
+ std::string correlation_id_str = "";
+
+ if (py::isinstance<py::int_>(correlation_id)) {
+ correlation_id_int = py::cast<uint64_t>(correlation_id);
+ } else if (py::isinstance<py::str>(correlation_id)) {
+ correlation_id_str = py::cast<std::string>(correlation_id);
+ } else {
+ throw PythonBackendException(
+ "Correlation ID must be integer or string");
+ }

return std::make_shared<InferRequest>(
- request_id, correlation_id, inputs, requested_outputs,
- model_name, model_version, parameters_str, flags, timeout,
- 0 /*response_factory_address*/, 0 /*request_address*/,
- preferred_memory, trace);
+ request_id, correlation_id_int, correlation_id_str, inputs,
+ requested_outputs, model_name, model_version, parameters_str,
+ flags, timeout, 0 /*response_factory_address*/,
+ 0 /*request_address*/, preferred_memory, trace);
}),
py::arg("request_id").none(false) = "",
py::arg("correlation_id").none(false) = 0,
@@ -1670,6 +1684,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
py::return_value_policy::reference_internal)
.def("request_id", &InferRequest::RequestId)
.def("correlation_id", &InferRequest::CorrelationId)
.def("correlation_id_string", &InferRequest::CorrelationIdString)

Review comment (Member): Let's only have a single correlation_id field and it should return string or int depending on which one has been set.

.def("flags", &InferRequest::Flags)
.def("set_flags", &InferRequest::SetFlags)
.def("timeout", &InferRequest::Timeout)
26 changes: 16 additions & 10 deletions src/python_be.cc
@@ -362,9 +362,14 @@ ModelInstanceState::SaveRequestsToSharedMemory(
const char* id;
RETURN_IF_ERROR(TRITONBACKEND_RequestId(request, &id));

- uint64_t correlation_id;
- RETURN_IF_ERROR(
- TRITONBACKEND_RequestCorrelationId(request, &correlation_id));
+ uint64_t correlation_id = 0;
+ const char* correlation_id_string = "";
+
+ auto error = TRITONBACKEND_RequestCorrelationId(request, &correlation_id);
+ if (error != nullptr) {
+ RETURN_IF_ERROR(TRITONBACKEND_RequestCorrelationIdString(
+ request, &correlation_id_string));
+ }

uint32_t flags;
RETURN_IF_ERROR(TRITONBACKEND_RequestFlags(request, &flags));
@@ -390,17 +395,18 @@ ModelInstanceState::SaveRequestsToSharedMemory(
RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request));

infer_request = std::make_unique<InferRequest>(
- id, correlation_id, pb_input_tensors, requested_output_names,
- model_state->Name(), model_state->Version(), parameters_string, flags,
- request_timeout, reinterpret_cast<intptr_t>(factory_ptr),
+ id, correlation_id, correlation_id_string, pb_input_tensors,
+ requested_output_names, model_state->Name(), model_state->Version(),
+ parameters_string, flags, request_timeout,
+ reinterpret_cast<intptr_t>(factory_ptr),
reinterpret_cast<intptr_t>(request),
PreferredMemory(PreferredMemory::kDefault, 0), trace);
} else {
infer_request = std::make_unique<InferRequest>(
- id, correlation_id, pb_input_tensors, requested_output_names,
- model_state->Name(), model_state->Version(), parameters_string, flags,
- request_timeout, 0 /* response_factory_address */,
- reinterpret_cast<intptr_t>(request),
+ id, correlation_id, correlation_id_string, pb_input_tensors,
+ requested_output_names, model_state->Name(), model_state->Version(),
+ parameters_string, flags, request_timeout,
+ 0 /* response_factory_address */, reinterpret_cast<intptr_t>(request),
PreferredMemory(PreferredMemory::kDefault, 0), trace);
}

9 changes: 7 additions & 2 deletions src/request_executor.cc
@@ -354,8 +354,13 @@ RequestExecutor::Infer(
THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetId(
irequest, infer_request->RequestId().c_str()));

- THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationId(
- irequest, infer_request->CorrelationId()));
+ if (infer_request->CorrelationIdString().empty()) {
+ THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationId(
+ irequest, infer_request->CorrelationId()));
+ } else {
+ THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationIdString(
+ irequest, infer_request->CorrelationIdString().c_str()));
+ }

THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetFlags(
irequest, infer_request->Flags()));
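
One consequence of the dispatch above is that an empty correlation string is treated as "not set". A toy Python mirror of the branch, illustrative only and not backend code:

```python
def pick_correlation_setter(correlation_id_int: int, correlation_id_str: str):
    # Mirrors RequestExecutor::Infer: the string setter is used only when
    # a non-empty string was captured; otherwise the integer setter wins.
    if correlation_id_str == "":
        return ("TRITONSERVER_InferenceRequestSetCorrelationId",
                correlation_id_int)
    return ("TRITONSERVER_InferenceRequestSetCorrelationIdString",
            correlation_id_str)


# Edge case: correlation_id="" from Python degrades to the integer
# default of 0, indistinguishable from an unset correlation ID.
assert pick_correlation_setter(0, "") == (
    "TRITONSERVER_InferenceRequestSetCorrelationId", 0)
```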