From 4d1bcc67f444f575ca129456e2a1294a0f30d233 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:22:39 -0700 Subject: [PATCH 1/2] Add parameters support to InferenceRequest --- src/pb_stub.cc | 26 ++++++++++++++++++++++---- src/request_executor.cc | 30 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index b38f8d38..f6987ba2 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -1464,15 +1464,32 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) const int64_t model_version, const uint32_t flags, const int32_t timeout, const PreferredMemory& preferred_memory, - const InferenceTrace& trace) { + const InferenceTrace& trace, const py::dict& parameters) { std::set<std::string> requested_outputs; for (auto& requested_output_name : requested_output_names) { requested_outputs.emplace(requested_output_name); } - // FIXME: InferenceRequest parameters are not supported in BLS now. 
+ for (const auto& pair : parameters) { + if (!py::isinstance<py::str>(pair.first)) { + throw PythonBackendException( + "Expect parameters keys to have type str, found type " + + std::string(py::str(pair.first.get_type()))); + } + if (!py::isinstance<py::bool_>(pair.second) && + !py::isinstance<py::int_>(pair.second) && + !py::isinstance<py::str>(pair.second)) { + throw PythonBackendException( + "Expect parameters values to have type bool/int/str, found " + "type " + + std::string(py::str(pair.second.get_type()))); + } + } + py::module_ py_json = py::module_::import("json"); + std::string parameters_str = + py::str(py_json.attr("dumps")(parameters)); return std::make_shared<InferRequest>( request_id, correlation_id, inputs, requested_outputs, - model_name, model_version, "" /*parameters*/, flags, timeout, + model_name, model_version, parameters_str, flags, timeout, 0 /*response_factory_address*/, 0 /*request_address*/, preferred_memory, trace); }), @@ -1485,7 +1502,8 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::arg("flags").none(false) = 0, py::arg("timeout").none(false) = 0, py::arg("preferred_memory").none(false) = PreferredMemory(PreferredMemory::DEFAULT, 0), - py::arg("trace").none(false) = InferenceTrace()) + py::arg("trace").none(false) = InferenceTrace(), + py::arg("parameters").none(false) = py::dict()) .def( "inputs", &InferRequest::Inputs, py::return_value_policy::reference_internal) diff --git a/src/request_executor.cc b/src/request_executor.cc index b54e3988..2a6d9575 100644 --- a/src/request_executor.cc +++ b/src/request_executor.cc @@ -365,6 +365,36 @@ RequestExecutor::Infer( infer_request->Trace().triton_trace_, &trace)); } + const std::string& param_str = infer_request->Parameters(); + triton::common::TritonJson::Value param; + THROW_IF_TRITON_ERROR(param.Parse(param_str.c_str(), param_str.length())); + std::vector<std::string> param_keys; + THROW_IF_TRITON_ERROR(param.Members(&param_keys)); + for (const auto& key : param_keys) { + triton::common::TritonJson::Value value; + if 
(!param.Find(key.c_str(), &value)) { + throw PythonBackendException("Unexpected missing key on parameters"); + } + if (value.IsString()) { + std::string string_value; + THROW_IF_TRITON_ERROR(value.AsString(&string_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetStringParameter( + irequest, key.c_str(), string_value.c_str())); + } else if (value.IsInt()) { + int64_t int_value = 0; + THROW_IF_TRITON_ERROR(value.AsInt(&int_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetIntParameter( + irequest, key.c_str(), int_value)); + } else if (value.IsBool()) { + bool bool_value = false; + THROW_IF_TRITON_ERROR(value.AsBool(&bool_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetBoolParameter( + irequest, key.c_str(), bool_value)); + } else { + throw PythonBackendException("Unsupported value type on parameters"); + } + } + for (auto& infer_input : infer_request->Inputs()) { THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestAddInput( irequest, infer_input->Name().c_str(), From f25ffdb6f0152e4c2f9c23f979ada3fb0b846619 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:27:05 -0700 Subject: [PATCH 2/2] Safeguard default argument against mutation --- src/pb_stub.cc | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index f6987ba2..bc929525 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -76,6 +76,27 @@ SignalHandler(int signum) // Skip the SIGINT and SIGTERM } +template <typename PYTYPE> +PYTYPE +PyDefaultArgumentToMutableType(const py::object& argument) +{ + // The default argument on Python functions always reference the same copy, + // meaning if the default argument is changed by the function, then it is + // changed for all subsequent calls to the function. Thus, default arguments + // should be limited to basic types (i.e. None). 
This helper function returns + // an empty expected type, if the argument is None (i.e. default initialized). + // If the argument is neither None nor expected type, an exception is thrown. + if (py::isinstance<py::none>(argument)) { + return PYTYPE(); + } + if (py::isinstance<PYTYPE>(argument)) { + return argument; + } + throw PythonBackendException( + std::string("Expect ") + typeid(PYTYPE).name() + ", got " + + std::string(py::str(argument.get_type()))); +} + void Stub::Instantiate( int64_t shm_growth_size, int64_t shm_default_size, @@ -1464,7 +1485,10 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) const int64_t model_version, const uint32_t flags, const int32_t timeout, const PreferredMemory& preferred_memory, - const InferenceTrace& trace, const py::dict& parameters) { + const InferenceTrace& trace, + const py::object& parameters_) { + py::dict parameters = + PyDefaultArgumentToMutableType<py::dict>(parameters_); std::set<std::string> requested_outputs; for (auto& requested_output_name : requested_output_names) { requested_outputs.emplace(requested_output_name); @@ -1503,7 +1527,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::arg("preferred_memory").none(false) = PreferredMemory(PreferredMemory::DEFAULT, 0), py::arg("trace").none(false) = InferenceTrace(), - py::arg("parameters").none(false) = py::dict()) + py::arg("parameters").none(true) = py::none()) .def( "inputs", &InferRequest::Inputs, py::return_value_policy::reference_internal)