feat: Client-side input shape/element validation #742
base: main
Changes from 3 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions | ||
|
@@ -26,6 +26,10 @@ | |
|
||
#include "common.h" | ||
|
||
#include <numeric> | ||
|
||
#include "triton/common/model_config.h" | ||
|
||
namespace triton { namespace client { | ||
|
||
//============================================================================== | ||
|
@@ -232,6 +236,113 @@ InferInput::SetBinaryData(const bool binary_data) | |
return Error::Success; | ||
} | ||
|
||
Error | ||
InferInput::GetStringCount(size_t* str_cnt) const | ||
{ | ||
int64_t str_checked = 0; | ||
size_t remaining_str_size = 0; | ||
|
||
size_t next_buf_idx = 0; | ||
const size_t buf_cnt = bufs_.size(); | ||
|
||
const uint8_t* buf = nullptr; | ||
size_t remaining_buf_size = 0; | ||
|
||
// Validate elements until all buffers have been fully processed. | ||
while (remaining_buf_size || next_buf_idx < buf_cnt) { | ||
|
||
// Get the next buf if not currently processing one. | ||
if (!remaining_buf_size) { | ||
// Reset remaining buf size and pointers for next buf. | ||
buf = bufs_[next_buf_idx]; | ||
remaining_buf_size = buf_byte_sizes_[next_buf_idx]; | ||
next_buf_idx++; | ||
} | ||
|
||
constexpr size_t kStringSizeIndicator = sizeof(uint32_t); | ||
// Get the next element if not currently processing one. | ||
if (!remaining_str_size) { | ||
// FIXME: Assume the string element's byte size indicator is not spread | ||
// across buf boundaries for simplicity. Also needs better log msg. | ||
if (remaining_buf_size < kStringSizeIndicator) { | ||
return Error("element byte size indicator exceeds the end of the buf."); | ||
} | ||
|
||
// Start the next element and reset the remaining element size. | ||
remaining_str_size = *(reinterpret_cast<const uint32_t*>(buf)); | ||
str_checked++; | ||
|
||
// Advance pointer and remainder by the indicator size. | ||
buf += kStringSizeIndicator; | ||
remaining_buf_size -= kStringSizeIndicator; | ||
} | ||
|
||
// If the remaining buf fits it: consume the rest of the element, proceed | ||
// to the next element. | ||
if (remaining_buf_size >= remaining_str_size) { | ||
buf += remaining_str_size; | ||
remaining_buf_size -= remaining_str_size; | ||
remaining_str_size = 0; | ||
} | ||
// Otherwise the remaining element is larger: consume the rest of the | ||
// buf, proceed to the next buf. | ||
else { | ||
remaining_str_size -= remaining_buf_size; | ||
remaining_buf_size = 0; | ||
} | ||
} | ||
|
||
// FIXME: If more than expected, should stop earlier | ||
// Validate the number of processed elements exactly match expectations. | ||
*str_cnt = str_checked; | ||
return Error::Success; | ||
} | ||
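As a standalone illustration of the scan performed by `GetStringCount`, here is a minimal sketch, assuming each BYTES element is serialized as a 4-byte length in host byte order followed by its payload. `Buffer`, `SerializeStrings`, `CountLengthPrefixedStrings`, and `RunExample` are hypothetical names used only for this example and are not part of the client library. The sketch differs from the diff in two ways: it reads the length with `std::memcpy` instead of a pointer cast, and it also rejects a final element whose payload is cut short.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for one user-provided buffer: (pointer, byte size).
using Buffer = std::pair<const uint8_t*, size_t>;

// Hypothetical helper: serialize strings as [uint32 length][payload]...
std::vector<uint8_t> SerializeStrings(const std::vector<std::string>& strs)
{
  std::vector<uint8_t> out;
  for (const std::string& s : strs) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const uint8_t* len_bytes = reinterpret_cast<const uint8_t*>(&len);
    out.insert(out.end(), len_bytes, len_bytes + sizeof(len));
    out.insert(out.end(), s.begin(), s.end());
  }
  return out;
}

// Counts length-prefixed elements spread over several buffers.
// Returns false if a length prefix straddles a buffer boundary (the same
// simplifying assumption as the diff) or if the last payload is truncated.
bool CountLengthPrefixedStrings(const std::vector<Buffer>& bufs, size_t* count)
{
  size_t seen = 0;
  size_t remaining_str = 0;  // payload bytes still owed to the current element
  for (const Buffer& b : bufs) {
    const uint8_t* p = b.first;
    size_t remaining_buf = b.second;
    while (remaining_buf > 0) {
      if (remaining_str == 0) {
        if (remaining_buf < sizeof(uint32_t)) {
          return false;
        }
        uint32_t len = 0;
        std::memcpy(&len, p, sizeof(len));  // safe for unaligned input
        p += sizeof(len);
        remaining_buf -= sizeof(len);
        remaining_str = len;
        ++seen;
      }
      // Consume as much payload as this buffer holds.
      const size_t consumed = std::min(remaining_buf, remaining_str);
      p += consumed;
      remaining_buf -= consumed;
      remaining_str -= consumed;
    }
  }
  if (remaining_str != 0) {
    return false;  // last element's payload ran past the final buffer
  }
  *count = seen;
  return true;
}

// Usage: three elements, with the payload of the first one split in two.
void RunExample()
{
  const std::vector<uint8_t> bytes = SerializeStrings({"ab", "", "xyz"});
  std::vector<Buffer> split;
  split.emplace_back(bytes.data(), 5);
  split.emplace_back(bytes.data() + 5, bytes.size() - 5);
  size_t count = 0;
  const bool ok = CountLengthPrefixedStrings(split, &count);
  assert(ok);
  assert(count == 3);
}
```

The caller would then compare the returned count against the element count implied by the input shape, as `ValidateData` does.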
|
||
Error | ||
InferInput::ValidateData() const | ||
{ | ||
Moving TRT reformat conversation to a thread 🧵

Yingge:
Sai:

Sai's changes will allow the check to work on core side, but probably not on client side, right? @yinggeh

You can query the platform/backend through the model config APIs on client side, which would work when inferring on a TensorRT model directly. You can probably even query

For an ensemble model containing one of these TRT models with non-linear inputs, you may need to follow the ensemble definition to find out if it's calling a TRT model with its inputs, which can be a pain. It may be simpler to skip the check on ensemble models and let the core check handle it (but it feels like we're starting to introduce a lot of special checks and cases with this feature).

For a BLS model, I think it's fine and will work as any other python model, then it will trigger the core check internally if the BLS is calling the TRT model. CC @tanmayv25

Another alternative is to introduce a new flag in client Input/Output tensors to skip the byte size check on the client side.
Cons:

@rmccorm4 Can you elaborate on this?

If you have an ensemble with ENSEMBLE_INPUT0 where the first step is a TRT model with non-linear IO INPUT0 and a mapping of ENSEMBLE_INPUT0 -> INPUT0, do we require an ensemble config to mention that the ENSEMBLE_INPUT0 is non-linear IO too? Or is it inferred internally?

+1 that I think this is counter-intuitive to the goal
If we are able to internally determine "the correct scenario" programmatically, isn't this the same as being able to skip internally without user specification?
|
||
inference::DataType datatype = | ||
triton::common::ProtocolStringToDataType(datatype_); | ||
if (io_type_ == SHARED_MEMORY) { | ||
if (datatype == inference::DataType::TYPE_STRING) { | ||
// TODO: Found no inference example that combines shared memory with BYTES inputs | ||
} else { | ||
int64_t expected_byte_size = | ||
triton::common::GetByteSize(datatype, shape_); | ||
if ((int64_t)byte_size_ != expected_byte_size) { | ||
return Error( | ||
"'" + name_ + "' got unexpected byte size " + | ||
std::to_string(byte_size_) + ", expected " + | ||
std::to_string(expected_byte_size)); | ||
} | ||
} | ||
} else { | ||
if (datatype == inference::DataType::TYPE_STRING) { | ||
int64_t expected_str_cnt = triton::common::GetElementCount(shape_); | ||
size_t str_cnt; | ||
Error err = GetStringCount(&str_cnt); | ||
if (!err.IsOk()) { | ||
return err; | ||
} | ||
if ((int64_t)str_cnt != expected_str_cnt) { | ||
return Error( | ||
"'" + name_ + "' got unexpected string count " + | ||
std::to_string(str_cnt) + ", expected " + | ||
std::to_string(expected_str_cnt)); | ||
} | ||
} else { | ||
int64_t expected_byte_size = | ||
triton::common::GetByteSize(datatype, shape_); | ||
if ((int64_t)byte_size_ != expected_byte_size) { | ||
return Error( | ||
"'" + name_ + "' got unexpected byte size " + | ||
std::to_string(byte_size_) + ", expected " + | ||
std::to_string(expected_byte_size)); | ||
} | ||
} | ||
} | ||
return Error::Success; | ||
} | ||
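The fixed-size branch of `ValidateData` reduces to one comparison: the number of bytes supplied must equal the element count implied by the shape times the element size. Below is a minimal sketch of that arithmetic, assuming the element size is already known. `ExpectedByteSize`, `CheckByteSize`, and `RunExample` are hypothetical names for this example and stand in for, rather than reproduce, the `triton::common` helpers used in the diff. Treating a negative dimension as "unknown" and skipping the check in that case is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: bytes needed for a tensor of fixed-size elements.
// Returns -1 if any dimension is negative (treated here as dynamic/unknown).
int64_t ExpectedByteSize(
    int64_t element_byte_size, const std::vector<int64_t>& shape)
{
  int64_t count = 1;  // an empty shape denotes a single scalar element
  for (int64_t dim : shape) {
    if (dim < 0) {
      return -1;
    }
    count *= dim;
  }
  return count * element_byte_size;
}

// Hypothetical helper: returns an empty string on success, otherwise a
// message in the same style as the diff's error text.
std::string CheckByteSize(
    const std::string& name, size_t actual_byte_size,
    int64_t element_byte_size, const std::vector<int64_t>& shape)
{
  const int64_t expected = ExpectedByteSize(element_byte_size, shape);
  if (expected < 0) {
    return "";  // shape not fully specified; nothing to compare against
  }
  if (static_cast<int64_t>(actual_byte_size) != expected) {
    return "'" + name + "' got unexpected byte size " +
           std::to_string(actual_byte_size) + ", expected " +
           std::to_string(expected);
  }
  return "";
}

// Usage: a [2, 3] tensor of 4-byte elements needs exactly 24 bytes.
void RunExample()
{
  const std::vector<int64_t> shape = {2, 3};
  assert(ExpectedByteSize(4, shape) == 24);
  assert(CheckByteSize("INPUT0", 24, 4, shape).empty());
  assert(!CheckByteSize("INPUT0", 20, 4, shape).empty());
}
```

Because the comparison is exact, both short and oversized buffers are rejected before the request leaves the client.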
|
||
Error | ||
InferInput::PrepareForRequest() | ||
{ | ||
|
Isn't proto-library used for protobuf<->grpc? Why is it needed for HTTP client?
edit: guessing the requirement is here.
Are there any concerns with introducing the new protobuf dependency to the HTTP client, or any alternatives? CC @GuanLuo @tanmayv25