Skip to content

Commit

Permalink
fix randomized dropout for categories and implement max_cat_threshold
Browse files: browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Jan 2, 2025
1 parent aac748e commit e3aa1b2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/interpret-core/interpret/develop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"min_cat_samples": 12,
"min_cat_hessian_percent": 0.0,
"cat_smooth": math.inf, # math.inf means use only the gradient for sorting
"max_cat_threshold": 32,
"max_cat_threshold": 9223372036854775807,
"cat_include": 0.75,
"purify_boosting": False,
"purify_result": False,
Expand Down
8 changes: 5 additions & 3 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -982,8 +982,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
// TODO: use all of these!
UNUSED(bUnseen);
UNUSED(categoryHessianPercentMin);
UNUSED(categoricalThresholdMax);
UNUSED(categoricalInclusionPercent);

BoosterCore* const pBoosterCore = pBoosterShell->GetBoosterCore();
const size_t cScores = GET_COUNT_SCORES(cCompilerScores, pBoosterCore->GetCountScores());
Expand Down Expand Up @@ -1159,14 +1157,18 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
if(cRemaining <= cKeep && categoricalInclusionPercent < 1.0) {
cKeep = cRemaining - 1;
}
if(categoricalThresholdMax < cKeep) {
cKeep = categoricalThresholdMax;
}
if(cKeep <= 1) {
cKeep = 2;
}
if(cRemaining < cKeep) {
cKeep = cRemaining;
}
EBM_ASSERT(2 <= cKeep);

const bool bShuffle = 1 != cCompilerScores || std::isnan(categoricalSmoothing);
const bool bShuffle = 1 != cCompilerScores || std::isnan(categoricalSmoothing) || cKeep != cRemaining;
const bool bSort = 1 == cCompilerScores && !std::isnan(categoricalSmoothing);

EBM_ASSERT(bShuffle || bSort);
Expand Down
2 changes: 1 addition & 1 deletion shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
}

TEST_CASE("stress test, boosting") {
const double expected = 12442461586398.865;
const double expected = 12554194225282.529;

double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
CHECK(validationMetricExact == expected);
Expand Down

0 comments on commit e3aa1b2

Please sign in to comment.