From ccd5db5f2931c39e8aa0eb0bba6b53868e357d5a Mon Sep 17 00:00:00 2001 From: Paul Koch Date: Thu, 2 Jan 2025 20:40:02 -0800 Subject: [PATCH] add cat_penalty develop parameter --- python/interpret-core/interpret/develop.py | 1 + python/interpret-core/interpret/glassbox/_ebm/_boost.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/python/interpret-core/interpret/develop.py b/python/interpret-core/interpret/develop.py index ff430728c..20bf08bbf 100644 --- a/python/interpret-core/interpret/develop.py +++ b/python/interpret-core/interpret/develop.py @@ -21,6 +21,7 @@ "cat_smooth": math.inf, # math.inf means use only the gradient for sorting "max_cat_threshold": 9223372036854775807, "cat_include": 0.75, + "cat_penalty": 0.0, "purify_boosting": False, "purify_result": False, "randomize_initial_feature_order": True, diff --git a/python/interpret-core/interpret/glassbox/_ebm/_boost.py b/python/interpret-core/interpret/glassbox/_ebm/_boost.py index 8d49fd9aa..37da2da53 100644 --- a/python/interpret-core/interpret/glassbox/_ebm/_boost.py +++ b/python/interpret-core/interpret/glassbox/_ebm/_boost.py @@ -199,6 +199,11 @@ def boost( max_leaves=max_leaves, monotone_constraints=term_monotone, ) + + if contains_nominals and len(term_features[term_idx]) == 1: + # penalize nominals a bit because they benefit from sorting categories + avg_gain *= 1.0 - develop.get_option("cat_penalty") + gainkey = (-avg_gain, native.generate_seed(rng), term_idx) if not make_progress: if bestkey is None or gainkey < bestkey: