Commit

Expose allow_unknown_field from XGBoost loader (#367)
* Expose allow_unknown_field from XGBoost loader

* Test allow_unknown_field in tests

* Update doc

* Fix spelling error

---------

Co-authored-by: William Hicks <[email protected]>
hcho3 and wphicks authored Oct 6, 2023
1 parent 65f0416 commit b0b5482
Showing 7 changed files with 37 additions and 8 deletions.
12 changes: 12 additions & 0 deletions docs/model_config.md
@@ -233,6 +233,18 @@ parameters [
 ]
 ```

+### Allow unknown fields in XGBoost JSON model (`xgboost_allow_unknown_field`)
+For XGBoost JSON models, ignore unknown fields instead of throwing a validation
+error. This flag is ignored for other kinds of models.
+```
+parameters [
+{
+key: "xgboost_allow_unknown_field"
+value: { string_value: "true" }
+}
+]
+```
+
 ### Decision Threshold
 For binary classifiers, it is sometimes helpful to set a specific
 confidence threshold for positive decisions. This can be set via the
2 changes: 1 addition & 1 deletion notebooks/faq/FAQs.ipynb
@@ -2425,7 +2425,7 @@
 "source": [
 "Learning-to-rank models are treated as regression models in the FIL backend. In the configuration file, make sure to set `output_class=\"false\"`.\n",
 "\n",
-"If the learning-to-rank model was trained with dense data, no extra preperation is needed. As in Example 13.2, you can obtain predictions just like other regression models.\n",
+"If the learning-to-rank model was trained with dense data, no extra preparation is needed. As in Example 13.2, you can obtain predictions just like other regression models.\n",
 "\n",
 "Special care is needed when the learning-to-rank model was trained with sparse data. Since the FIL backend does not yet support a sparse input, we need to use an equivalent dense input instead. Example 13.3 shows how to convert a sparse input into an equivalent dense input. [^](#Table-of-Contents)"
 ]
4 changes: 4 additions & 0 deletions qa/L0_e2e/generate_example_model.py
@@ -424,6 +424,10 @@ def generate_config(
 {{
 key: "use_experimental_optimizations"
 value: {{ string_value: "{use_experimental_optimizations}" }}
+}},
+{{
+key: "xgboost_allow_unknown_field"
+value: {{ string_value: "true" }}
 }}
 ]
3 changes: 2 additions & 1 deletion src/model.h
@@ -190,7 +190,8 @@ struct RapidsModel : rapids::Model<RapidsSharedState> {
 // Load model via Treelite
 auto tl_model = std::make_shared<TreeliteModel>(
 model_file(), shared_state->model_format(), shared_state->config(),
-shared_state->predict_proba(), shared_state->use_herring());
+shared_state->predict_proba(), shared_state->use_herring(),
+shared_state->xgboost_allow_unknown_field());


 if (get_deployment_type() == rapids::GPUDeployment) {
7 changes: 7 additions & 0 deletions src/shared_state.h
@@ -41,6 +41,8 @@ struct RapidsSharedState : rapids::SharedModelState {
 predict_proba_ = get_config_param<bool>("predict_proba", false);
 model_format_ = string_to_serialization(
 get_config_param<std::string>("model_type", std::string{"xgboost"}));
+xgboost_allow_unknown_field_ =
+get_config_param<bool>("xgboost_allow_unknown_field", false);
 transfer_threshold_ = get_config_param<std::size_t>(
 "transfer_threshold", DEFAULT_TRANSFER_THRESHOLD);

@@ -64,13 +66,18 @@ struct RapidsSharedState : rapids::SharedModelState {

 auto predict_proba() const { return predict_proba_; }
 auto model_format() const { return model_format_; }
+auto xgboost_allow_unknown_field() const
+{
+return xgboost_allow_unknown_field_;
+}
 auto transfer_threshold() const { return transfer_threshold_; }
 auto config() const { return tl_config_; }
 auto use_herring() const { return use_herring_; }

 private:
 bool predict_proba_{};
 SerializationFormat model_format_{};
+bool xgboost_allow_unknown_field_{};
 std::size_t transfer_threshold_{};
 std::shared_ptr<treelite_config> tl_config_ =
 std::make_shared<treelite_config>();
8 changes: 5 additions & 3 deletions src/tl_model.h
@@ -37,10 +37,12 @@ struct TreeliteModel {
 TreeliteModel(
 std::filesystem::path const& model_file, SerializationFormat format,
 std::shared_ptr<treelite_config> tl_config, bool predict_proba,
-bool use_herring)
+bool use_herring, bool xgboost_allow_unknown_field)
 : tl_config_{tl_config},
-base_tl_model_{[&model_file, &format, predict_proba, this]() {
-auto result = load_tl_base_model(model_file, format);
+base_tl_model_{[&model_file, &format, predict_proba,
+xgboost_allow_unknown_field, this]() {
+auto result = load_tl_base_model(
+model_file, format, xgboost_allow_unknown_field);
 auto num_classes = tl_get_num_classes(*base_tl_model_);
 if (!predict_proba && tl_config_->output_class && num_classes > 1) {
 std::strcpy(result->param.pred_transform, "max_index");
9 changes: 6 additions & 3 deletions src/tl_utils.h
@@ -32,7 +32,8 @@ namespace triton { namespace backend { namespace NAMESPACE {

 inline auto
 load_tl_base_model(
-std::filesystem::path const& model_file, SerializationFormat format)
+std::filesystem::path const& model_file, SerializationFormat format,
+bool xgboost_allow_unknown_field)
 {
 auto result = std::unique_ptr<treelite::Model>{};

@@ -42,9 +43,11 @@ load_tl_base_model(
 result = treelite::frontend::LoadXGBoostModel(model_file.c_str());
 break;
 case SerializationFormat::xgboost_json: {
-auto config_str = "{}";
+auto config_str =
+std::string("{\"allow_unknown_field\": ") +
+std::string(xgboost_allow_unknown_field ? "true" : "false") + "}";
 result = treelite::frontend::LoadXGBoostJSONModel(
-model_file.c_str(), config_str);
+model_file.c_str(), config_str.c_str());
 break;
 }
 case SerializationFormat::lightgbm:
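
Note on the `src/tl_utils.h` change: the loader configuration is assembled by string concatenation and then passed to Treelite as a C string. As a minimal standalone sketch of that construction, the same expression can be isolated and checked without Treelite. The helper name `make_xgboost_json_config` is hypothetical and is not part of this commit.

```cpp
#include <string>

// Hypothetical helper (not in the commit): mirrors the string concatenation
// used in load_tl_base_model for the xgboost_json case, so the resulting
// JSON text can be inspected in isolation.
inline std::string
make_xgboost_json_config(bool xgboost_allow_unknown_field)
{
  return std::string("{\"allow_unknown_field\": ") +
         (xgboost_allow_unknown_field ? "true" : "false") + "}";
}
```

Holding the result in a `std::string` and passing `config_str.c_str()` to the loader, as the commit does, keeps the underlying buffer alive for the duration of the call.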