Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameters support to InferenceRequest #313

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,27 @@ SignalHandler(int signum)
// Skip the SIGINT and SIGTERM
}

template <typename PYTYPE>
PYTYPE
PyDefaultArgumentToMutableType(const py::object& argument)
{
// The default argument on Python functions always reference the same copy,
// meaning if the default argument is changed by the function, then it is
// changed for all subsequent calls to the function. Thus, default arguments
// should be limited to basic types (i.e. None). This helper function returns
// an empty expected type, if the argument is None (i.e. default initialized).
// If the argument is neither None nor expected type, an exception is thrown.
if (py::isinstance<py::none>(argument)) {
return PYTYPE();
}
if (py::isinstance<PYTYPE>(argument)) {
return argument;
}
throw PythonBackendException(
std::string("Expect ") + typeid(PYTYPE).name() + ", got " +
std::string(py::str(argument.get_type())));
}

void
Stub::Instantiate(
int64_t shm_growth_size, int64_t shm_default_size,
Expand Down Expand Up @@ -1464,15 +1485,35 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
const int64_t model_version, const uint32_t flags,
const int32_t timeout,
const PreferredMemory& preferred_memory,
const InferenceTrace& trace) {
const InferenceTrace& trace,
const py::object& parameters_) {
py::dict parameters =
PyDefaultArgumentToMutableType<py::dict>(parameters_);
std::set<std::string> requested_outputs;
for (auto& requested_output_name : requested_output_names) {
requested_outputs.emplace(requested_output_name);
}
// FIXME: InferenceRequest parameters are not supported in BLS now.
for (const auto& pair : parameters) {
if (!py::isinstance<py::str>(pair.first)) {
throw PythonBackendException(
"Expect parameters keys to have type str, found type " +
std::string(py::str(pair.first.get_type())));
}
if (!py::isinstance<py::bool_>(pair.second) &&
!py::isinstance<py::int_>(pair.second) &&
!py::isinstance<py::str>(pair.second)) {
throw PythonBackendException(
"Expect parameters values to have type bool/int/str, found "
"type " +
std::string(py::str(pair.second.get_type())));
}
}
py::module_ py_json = py::module_::import("json");
std::string parameters_str =
py::str(py_json.attr("dumps")(parameters));
return std::make_shared<InferRequest>(
request_id, correlation_id, inputs, requested_outputs,
model_name, model_version, "" /*parameters*/, flags, timeout,
model_name, model_version, parameters_str, flags, timeout,
0 /*response_factory_address*/, 0 /*request_address*/,
preferred_memory, trace);
}),
Expand All @@ -1485,7 +1526,8 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
py::arg("flags").none(false) = 0, py::arg("timeout").none(false) = 0,
py::arg("preferred_memory").none(false) =
PreferredMemory(PreferredMemory::DEFAULT, 0),
py::arg("trace").none(false) = InferenceTrace())
py::arg("trace").none(false) = InferenceTrace(),
py::arg("parameters").none(true) = py::none())
.def(
"inputs", &InferRequest::Inputs,
py::return_value_policy::reference_internal)
Expand Down
30 changes: 30 additions & 0 deletions src/request_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,36 @@ RequestExecutor::Infer(
infer_request->Trace().triton_trace_, &trace));
}

const std::string& param_str = infer_request->Parameters();
triton::common::TritonJson::Value param;
THROW_IF_TRITON_ERROR(param.Parse(param_str.c_str(), param_str.length()));
std::vector<std::string> param_keys;
THROW_IF_TRITON_ERROR(param.Members(&param_keys));
for (const auto& key : param_keys) {
triton::common::TritonJson::Value value;
if (!param.Find(key.c_str(), &value)) {
throw PythonBackendException("Unexpected missing key on parameters");
}
if (value.IsString()) {
std::string string_value;
THROW_IF_TRITON_ERROR(value.AsString(&string_value));
THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetStringParameter(
irequest, key.c_str(), string_value.c_str()));
} else if (value.IsInt()) {
int64_t int_value = 0;
THROW_IF_TRITON_ERROR(value.AsInt(&int_value));
THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetIntParameter(
irequest, key.c_str(), int_value));
} else if (value.IsBool()) {
bool bool_value = false;
THROW_IF_TRITON_ERROR(value.AsBool(&bool_value));
THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetBoolParameter(
kthui marked this conversation as resolved.
Show resolved Hide resolved
irequest, key.c_str(), bool_value));
} else {
throw PythonBackendException("Unsupported value type on parameters");
}
}

for (auto& infer_input : infer_request->Inputs()) {
THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestAddInput(
irequest, infer_input->Name().c_str(),
Expand Down
Loading