diff --git a/python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py b/python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py index 22eb63ec7..5cfe2316a 100644 --- a/python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py +++ b/python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py @@ -3,12 +3,24 @@ import logging import warnings +from collections.abc import Sequence +from functools import reduce from itertools import chain, count -from math import isnan +from math import isnan, prod +from typing import List, Type import numpy as np from ...utils._native import Native +from ._ebm import ( + DPExplainableBoostingClassifier, + DPExplainableBoostingRegressor, + EBMModel, + ExplainableBoostingClassifier, + ExplainableBoostingRegressor, + _clean_exclude, + is_private, +) from ._utils import ( convert_categorical_to_continuous, deduplicate_bins, @@ -48,15 +60,15 @@ def _harmonize_tensor( # greater than the old model's lowest cut. # eg: new: | | | | | # old: | | - # other1: | | proprotion | + # other1: | | proportion | # other2: | proportion | # One wrinkle is that for pairs, we'll be using the pair cuts and we need to - # one-dimensionalize any existing pair weights onto their respective 1D axies - # before proportionating them. Annother issue is that we might not even have + # one-dimensionalize any existing pair weights onto their respective 1D axes + # before proportionating them. Another issue is that we might not even have # another term_feature that uses some particular feature that we use in our model # so we don't have any weights. We can solve that issue by dropping any feature's # bins for terms that we have no information for. After we do this we'll have - # guaranteed that we only have new bin cuts for feature axies that we have inside + # guaranteed that we only have new bin cuts for feature axes that we have inside # the bin level that we're handling! old_feature_idxs = list(old_feature_idxs) @@ -156,33 +168,37 @@ def _harmonize_tensor( percentage.append(1.0) else: for new_idx_minus_one, old_idx in enumerate(lookup): - if new_idx_minus_one == 0: - new_low = new_bounds[feature_idx, 0] - # TODO: if nan OR out of bounds from the cuts, estimate it. - # If -inf or +inf, change it to min/max for float - else: - new_low = new_feature_bins[new_idx_minus_one - 1] + # TODO: if nan OR out of bounds from the cuts, estimate it. + # If -inf or +inf, change it to min/max for float + new_low = ( + new_bounds[feature_idx, 0] + if new_idx_minus_one == 0 + else new_feature_bins[new_idx_minus_one - 1] + ) - if len(new_feature_bins) <= new_idx_minus_one: - new_high = new_bounds[feature_idx, 1] - # TODO: if nan OR out of bounds from the cuts, estimate it. - # If -inf or +inf, change it to min/max for float - else: - new_high = new_feature_bins[new_idx_minus_one] + # TODO: if nan OR out of bounds from the cuts, estimate it. + # If -inf or +inf, change it to min/max for float + new_high = ( + new_bounds[feature_idx, 1] + if len(new_feature_bins) <= new_idx_minus_one + else new_feature_bins[new_idx_minus_one] + ) - if old_idx == 1: - old_low = old_bounds[feature_idx, 0] - # TODO: if nan OR out of bounds from the cuts, estimate it. - # If -inf or +inf, change it to min/max for float - else: - old_low = old_feature_bins[old_idx - 2] + # TODO: if nan OR out of bounds from the cuts, estimate it. 
+ # If -inf or +inf, change it to min/max for float + old_low = ( + old_bounds[feature_idx, 0] + if old_idx == 1 + else old_feature_bins[old_idx - 2] + ) - if len(old_feature_bins) < old_idx: - old_high = old_bounds[feature_idx, 1] - # TODO: if nan OR out of bounds from the cuts, estimate it. - # If -inf or +inf, change it to min/max for float - else: - old_high = old_feature_bins[old_idx - 1] + # TODO: if nan OR out of bounds from the cuts, estimate it. + # If -inf or +inf, change it to min/max for float + old_high = ( + old_bounds[feature_idx, 1] + if len(old_feature_bins) < old_idx + else old_feature_bins[old_idx - 1] + ) if old_high <= new_low or new_high <= old_low: # if there are bins in the area above where the old data extended, then @@ -242,7 +258,7 @@ def _harmonize_tensor( map_bins[bin_idx] for map_bins, bin_idx in zip(mapping, old_reversed_bin_idxs) ] - n_cells2 = np.prod([len(x) for x in cell_map]) + n_cells2 = prod(map(len, cell_map)) val = 0 if n_multiclasses == 1 else np.zeros(n_multiclasses, np.float64) total_weight = 0.0 for cell2_idx in range(n_cells2): @@ -278,156 +294,56 @@ def _harmonize_tensor( return new_tensor.reshape(new_shape) -def merge_ebms(models): - """Merges EBM models trained on similar datasets that have the same set of features. - - Args: - models: List of EBM models to be merged. - - Returns: - An EBM model with averaged mean and standard deviation of input models. - """ - - if len(models) == 0: # pragma: no cover - msg = "0 models to merge." - raise Exception(msg) - - model_types = list(set(map(type, models))) - if len(model_types) == 2: - type_names = [model_type.__name__ for model_type in model_types] - if ( - "ExplainableBoostingClassifier" in type_names - and "DPExplainableBoostingClassifier" in type_names - ): - ebm_type = model_types[type_names.index("ExplainableBoostingClassifier")] - is_classification = True - is_dp = False - elif ( - "ExplainableBoostingRegressor" in type_names - and "DPExplainableBoostingRegressor" in type_names - ): - ebm_type = model_types[type_names.index("ExplainableBoostingRegressor")] - is_classification = False - is_dp = False - else: - msg = "Inconsistent model types attempting to be merged." - raise Exception(msg) - elif len(model_types) == 1: - ebm_type = model_types[0] - if ebm_type.__name__ == "ExplainableBoostingClassifier": - is_classification = True - is_dp = False - elif ebm_type.__name__ == "DPExplainableBoostingClassifier": - is_classification = True - is_dp = True - elif ebm_type.__name__ == "ExplainableBoostingRegressor": - is_classification = False - is_dp = False - elif ebm_type.__name__ == "DPExplainableBoostingRegressor": - is_classification = False - is_dp = True - else: - msg = f"Invalid EBM model type {ebm_type.__name__} attempting to be merged." - raise Exception(msg) - else: - msg = "Inconsistent model types being merged." - raise Exception(msg) - - # TODO: create the ExplainableBoostingClassifier etc, type directly - # by name instead of using __new__ from ebm_type - ebm = ebm_type.__new__(ebm_type) - - if any( - not getattr(model, "has_fitted_", False) for model in models - ): # pragma: no cover - msg = "All models must be fitted." 
- raise Exception(msg) - ebm.has_fitted_ = True - - link = models[0].link_ - if any(model.link_ != link for model in models): - msg = "Models with different link functions cannot be merged" - raise Exception(msg) - ebm.link_ = link - - link_param = models[0].link_param_ - if isnan(link_param): - if not all(isnan(model.link_param_) for model in models): - msg = "Models with different link param values cannot be merged" - raise Exception(msg) - elif any(model.link_param_ != link_param for model in models): - msg = "Models with different link param values cannot be merged" - raise Exception(msg) - ebm.link_param_ = link_param - +def _assert_model_compatibility(models: List[EBMModel]) -> None: + """Check if models can be merged, raise error if not.""" # self.bins_ is the only feature based attribute that we absolutely require n_features = len(models[0].bins_) - for model in models: if n_features != len(model.bins_): # pragma: no cover msg = "Inconsistent numbers of features in the models." raise Exception(msg) - feature_names_in = getattr(model, "feature_names_in_", None) - if feature_names_in is not None: - if n_features != len(feature_names_in): # pragma: no cover - msg = "Inconsistent numbers of features in the models." - raise Exception(msg) - - feature_types_in = getattr(model, "feature_types_in_", None) - if feature_types_in is not None: - if n_features != len(feature_types_in): # pragma: no cover - msg = "Inconsistent numbers of features in the models." - raise Exception(msg) - - feature_bounds = getattr(model, "feature_bounds_", None) - if feature_bounds is not None: - if n_features != feature_bounds.shape[0]: # pragma: no cover - msg = "Inconsistent numbers of features in the models." - raise Exception(msg) - - histogram_weights = getattr(model, "histogram_weights_", None) - if histogram_weights is not None: - if n_features != len(histogram_weights): # pragma: no cover - msg = "Inconsistent numbers of features in the models." - raise Exception(msg) + if hasattr(model, "feature_names_in_") and n_features != len( + model.feature_names_in_ + ): # pragma: no cover + msg = "Inconsistent numbers of features in the models." + raise Exception(msg) - unique_val_counts = getattr(model, "unique_val_counts_", None) - if unique_val_counts is not None: - if n_features != len(unique_val_counts): # pragma: no cover - msg = "Inconsistent numbers of features in the models." - raise Exception(msg) + if hasattr(model, "feature_types_in_") and n_features != len( + model.feature_types_in_ + ): # pragma: no cover + msg = "Inconsistent numbers of features in the models." + raise Exception(msg) - old_bounds = [] - old_mapping = [] - old_bins = [] - for model in models: - if any(len(set(map(type, bin_levels))) != 1 for bin_levels in model.bins_): - msg = "Inconsistent bin types within a model." + if ( + hasattr(model, "feature_bounds_") + and n_features != model.feature_bounds_.shape[0] + ): # pragma: no cover + msg = "Inconsistent numbers of features in the models." raise Exception(msg) - feature_bounds = getattr(model, "feature_bounds_", None) - if feature_bounds is None: - old_bounds.append(None) - else: - old_bounds.append(feature_bounds.copy()) + if hasattr(model, "histogram_weights_") and n_features != len( + model.histogram_weights_ + ): # pragma: no cover + msg = "Inconsistent numbers of features in the models." 
+ raise Exception(msg) - old_mapping.append([[] for _ in range(n_features)]) - old_bins.append([[] for _ in range(n_features)]) + if hasattr(model, "unique_val_counts_") and n_features != len( + model.unique_val_counts_ + ): # pragma: no cover + msg = "Inconsistent numbers of features in the models." + raise Exception(msg) - # TODO: every time we merge models we fragment the bins more and more and this is undesirable - # especially for pairs. When we build models, we store the feature bin cuts for pairs even - # if we have no pairs that use that paritcular feature as a pair. We can eliminate these useless - # pair feature cuts before merging the bins and that'll give us less resulting cuts. Having less - # cuts reduces the number of estimates that we need to make and reduces the complexity of the - # tensors, so it's good to have this reduction. +def _get_new_bins(models: List[EBMModel], *, old_mapping, old_bins, old_bounds): + n_features = len(models[0].bins_) new_feature_types = [] new_bins = [] for feature_idx in range(n_features): bin_types = {type(model.bins_[feature_idx][0]) for model in models} + is_categorical = len(bin_types) == 1 and next(iter(bin_types)) is dict - if len(bin_types) == 1 and next(iter(bin_types)) is dict: + if is_categorical: # categorical new_feature_type = None for model in models: @@ -440,8 +356,7 @@ def merge_ebms(models): new_feature_type = "ordinal" if new_feature_type is None: new_feature_type = "nominal" - else: - # continuous + else: # continuous if any(bin_type not in {dict, np.ndarray} for bin_type in bin_types): msg = "Invalid bin type." raise Exception(msg) @@ -460,8 +375,7 @@ def merge_ebms(models): old_mapping[model_idx][feature_idx].append(None) old_bins[model_idx][feature_idx].append(bin_level) - if len(bin_types) == 1 and next(iter(bin_types)) is dict: - # categorical + if is_categorical: merged_keys = sorted( set(chain.from_iterable(bin.keys() for bin in model_bins)) ) @@ -471,12 +385,10 @@ def merge_ebms(models): # order and also handling merged categories (where two categories map to a single score) # We should first try to progress in order along each set of keys and see if we can # establish the perfect order which might work if there are isolated missing categories - # and if we can't get a unique guaranteed sorted order that way then examime all the + # and if we can't get a unique guaranteed sorted order that way then examine all the # different known sort order and figure out if any of the possible orderings match merged_bins = dict(zip(merged_keys, count(1))) - else: - # continuous - + else: # continuous if len(bin_types) != 1: # We have both categorical and continuous. 
We can't convert continuous # to categorical since we lack the original labels, but we can convert @@ -506,11 +418,176 @@ def merge_ebms(models): old_bins[model_idx][feature_idx][level_idx] = converted_bins old_mapping[model_idx][feature_idx][level_idx] = mapping - merged_bins = np.array( - sorted(set(chain.from_iterable(model_bins))), np.float64 - ) + merged_bins = reduce(np.union1d, model_bins) new_leveled_bins.append(merged_bins) new_bins.append(new_leveled_bins) + return new_bins, new_feature_types + + +def _initialize_ebm(models: List[EBMModel], ebm_type: Type[EBMModel]) -> EBMModel: + """Fully initialize new model from existing `models`.""" + weights = np.fromiter( + (np.sum(model.bag_weights_) for model in models), dtype=np.float64 + ) + kdws = models[0].get_params() + # treated manually + for key in [ + "feature_names", + "feature_types", + "objective", + "exclude", + "interactions", + "monotone_constraints", + ]: + del kdws[key] + manual_kdws = {} + + # handle `exclude` + if all(getattr(model, "exclude", None) is not None for model in models): + # none of the models contains all feature_idxs + # merged EBM should exclude features included by none of the models + # -> overlap of all features + clean_excludes = [] + for model in models: + if model.exclude == "mains": + clean_excludes.append({(idx,) for idx in range(model.n_features_in_)}) + continue + feature_map = { + name: idx for idx, name in enumerate(model.feature_names_in_) + } + clean_excludes.append(_clean_exclude(model.exclude, feature_map)) + excluded = set.intersection(*clean_excludes) + manual_kdws["exclude"] = list(excluded) if excluded else None + + # handle `interactions` + if all( + isinstance(getattr(model, "interactions", None), Sequence) for model in models + ): + # merge all interactions, use sorted for deterministic outcome + manual_kdws["interactions"] = sorted( + {term for model in models for term in model.interactions} + ) + else: # convert all + + def get_float_interactions(model: EBMModel) -> float: + interactions = model.interactions + if isinstance(interactions, Sequence): + interactions = len(interactions) + if isinstance(interactions, int): + interactions /= model.n_features_in_ + return interactions + + values = [get_float_interactions(model) for model in models] + manual_kdws["interactions"] = np.average(values, weights=weights) + + # handle `monotone_constraints` + if all( + getattr(model, "monotone_constraints", None) is not None for model in models + ): + # if all models apply monotonicity constrains we can validate them + def monotone(args) -> int: + if all(val == +1 for val in args): + return +1 + if all(val == -1 for val in args): + return -1 + return 0 + + manual_kdws["monotone_constraints"] = [ + monotone(item) + for item in zip(*(model.monotone_constraints for model in models)) + ] + + for key in kdws: + values = np.array([getattr(model, key, np.nan) for model in models]) + nan_weight = np.copy(weights) + nan_weight[np.isnan(values)] = 0 + kdws[key] = np.average(values, weights=nan_weight) + return ebm_type(**kdws, **manual_kdws) + + +def _get_model_type(models: List[EBMModel]) -> Type[EBMModel]: + model_types = {type(mod) for mod in models} + if len(model_types) == 1: + return model_types.pop() + if len(model_types) == 2: + if model_types == { + ExplainableBoostingClassifier, + DPExplainableBoostingClassifier, + }: + return ExplainableBoostingClassifier + if model_types == { + ExplainableBoostingRegressor, + DPExplainableBoostingRegressor, + }: + return ExplainableBoostingRegressor + msg = 
f"Inconsistent model types attempting to be merged: {model_types}." + raise ValueError(msg) + + +def merge_ebms(models): + """Merge EBM models trained on similar datasets that have the same set of features. + + Args: + models: List of EBM models to be merged. + + Returns: + An EBM model with averaged mean and standard deviation of input models. + """ + if len(models) == 0: # pragma: no cover + msg = "0 models to merge." + raise Exception(msg) + + ebm_type = _get_model_type(models) + + if any(not getattr(model, "has_fitted_", False) for model in models): + msg = "All models must be fitted." + raise Exception(msg) + + ebm = _initialize_ebm(models, ebm_type=ebm_type) + ebm.has_fitted_ = True + + link = models[0].link_ + if any(model.link_ != link for model in models): + msg = "Models with different link functions cannot be merged" + raise Exception(msg) + ebm.link_ = link + + link_param = models[0].link_param_ + if isnan(link_param): + if not all(isnan(model.link_param_) for model in models): + msg = "Models with different link param values cannot be merged" + raise Exception(msg) + elif any(model.link_param_ != link_param for model in models): + msg = "Models with different link param values cannot be merged" + raise Exception(msg) + ebm.link_param_ = link_param + + # self.bins_ is the only feature based attribute that we absolutely require + n_features = len(models[0].bins_) + + _assert_model_compatibility(models) + + old_mapping = [[[] for _ in range(n_features)] for _ in models] + old_bins = [[[] for _ in range(n_features)] for _ in models] + old_bounds = [] + for model in models: + if any(len(set(map(type, bin_levels))) != 1 for bin_levels in model.bins_): + msg = "Inconsistent bin types within a model." + raise Exception(msg) + + feature_bounds = getattr(model, "feature_bounds_", None) + old_bounds.append(None if feature_bounds is None else feature_bounds.copy()) + + # TODO: every time we merge models we fragment the bins more and more and this is undesirable + # especially for pairs. When we build models, we store the feature bin cuts for pairs even + # if we have no pairs that use that particular feature as a pair. We can eliminate these useless + # pair feature cuts before merging the bins and that'll give us less resulting cuts. Having less + # cuts reduces the number of estimates that we need to make and reduces the complexity of the + # tensors, so it's good to have this reduction. + + new_bins, new_feature_types = _get_new_bins( + models, old_mapping=old_mapping, old_bins=old_bins, old_bounds=old_bounds + ) ebm.feature_types_in_ = new_feature_types deduplicate_bins(new_bins) ebm.bins_ = new_bins @@ -545,19 +622,21 @@ def merge_ebms(models): list(zip(min_feature_vals, max_feature_vals)), np.float64 ) - if not is_dp: - if all( + if ( + not is_private(ebm_type) + and hasattr(ebm, "feature_bounds_") + and all( hasattr(model, "histogram_weights_") and hasattr(model, "feature_bounds_") for model in models - ): - if hasattr(ebm, "feature_bounds_"): - # TODO: estimate the histogram bin counts by taking the min of the mins and the max of the maxes - # and re-apportioning the counts based on the distributions of the previous histograms. Proprotion - # them to the floor of their counts and then assign any remaining integers based on how much - # they reduce the RMSE of the integer counts from the ideal floating point counts. 
- pass - - if is_classification: + ) + ): + # TODO: estimate the histogram bin counts by taking the min of the mins and the max of the maxes + # and re-apportioning the counts based on the distributions of the previous histograms. Proportion + # them to the floor of their counts and then assign any remaining integers based on how much + # they reduce the RMSE of the integer counts from the ideal floating point counts. + pass + + if ebm_type in (ExplainableBoostingClassifier, DPExplainableBoostingClassifier): ebm.classes_ = models[0].classes_.copy() if any(not np.array_equal(ebm.classes_, model.classes_) for model in models): # pragma: no cover @@ -625,7 +704,7 @@ def merge_ebms(models): # TODO: in the future we might at this point try and figure out the most # common feature ordering within the terms. Take the mode first - # and amonst the orderings that tie, choose the one that's best sorted by + # and amongst the orderings that tie, choose the one that's best sorted by # feature indexes ebm.term_features_ = sorted_fgs @@ -636,26 +715,26 @@ def merge_ebms(models): # interaction mismatches where an interaction will be in one model, but not the other. # We need to estimate the bin_weight_ tensors that would have existed in this case. # We'll use the interaction terms that we do have in other models to estimate the - # distribution in the essense of the data, which should be roughly consistent or you + # distribution in the essence of the data, which should be roughly consistent or you # shouldn't be attempting to merge the models in the first place. We'll then scale - # the percentage distribution by the total weight of the model that we're fillin in the + # the percentage distribution by the total weight of the model that we're filling in the # details for. # TODO: this algorithm has some problems. The estimated tensor that we get by taking the # model weight and distributing it by a per-cell percentage measure means that we get - # inconsistent weight distibutions along the axis. We can take our resulting weight tensor + # inconsistent weight distributions along the axis. We can take our resulting weight tensor # and sum the columns/rows to get the weights on each individual feature axis. Our model # however comes with a known set of weights on each feature, and the result of our operation # will not match the existing distribution in almost all cases. I think there might be # some algorithm where we start with the per-feature weights and use the distribution hints # from the other models to inform where we place our exact weights that we know about in our - # model from each axis. The problem is that the sums in both axies need to agree, and each + # model from each axis. The problem is that the sums in both axes need to agree, and each # change we make influences both. I'm not sure we can even guarantee that there is an answer # and if there was one I'm not sure how we'd go about generating it. I'm going to leave # this problem for YOU: a future person who is smarter than me and has more time to solve this. # One hint: I think a possible place to start would be an iterative algorithm that's similar # to purification where you randomly select a row/column and try to get closer at each step - # to the rigth answer. Good luck! + # to the right answer. Good luck! # # Oh, there's also another deeper problem.. let's say you had a crazy 5 way interaction in the # model eg: (0,1,2,3,4) and you had 2 and 3 way interactions that either overlap or not. 
@@ -698,9 +777,8 @@ def merge_ebms(models): count(), models, fg_dicts, model_weights ): n_outer_bags = -1 - if hasattr(model, "bagged_scores_"): - if len(model.bagged_scores_) > 0: - n_outer_bags = len(model.bagged_scores_[0]) + if hasattr(model, "bagged_scores_") and len(model.bagged_scores_) > 0: + n_outer_bags = len(model.bagged_scores_[0]) term_idx = fg_dict.get(sorted_fg) if term_idx is None: @@ -772,7 +850,7 @@ def merge_ebms(models): # removing the higher order terms might allow us to eliminate some extra bins now that couldn't before deduplicate_bins(ebm.bins_) - # dependent attributes (can be re-derrived after serialization) + # dependent attributes (can be re-derived after serialization) ebm.n_features_in_ = len(ebm.bins_) # scikit-learn specified name ebm.term_names_ = generate_term_names(ebm.feature_names_in_, ebm.term_features_) diff --git a/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py b/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py index 9883596c9..449c0c1de 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py +++ b/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py @@ -2,9 +2,15 @@ # Distributed under the MIT software license import warnings +from functools import partial import numpy as np -from interpret.glassbox import ExplainableBoostingClassifier, merge_ebms +import pytest +from interpret.glassbox import ( + ExplainableBoostingClassifier, + ExplainableBoostingRegressor, + merge_ebms, +) from interpret.utils import make_synthetic from sklearn.model_selection import train_test_split @@ -13,8 +19,16 @@ smoke_test_explanations, ) +# arguments for faster fitting time to reduce test time +# we want to test the interface, not get good results +_fast_kwds = { + "outer_bags": 2, + "max_rounds": 100, +} + def valid_ebm(ebm): + assert repr(ebm), "Cannot represent EBM which is important for debugging" assert ebm.term_features_[0] == (0,) for term_scores in ebm.term_scores_: @@ -23,12 +37,6 @@ def valid_ebm(ebm): def test_merge_ebms(): - # TODO: improve this test by checking the merged ebms for validity. - # Right now the merged ebms fail the check for valid_ebm. 
- # The failure might be related to the warning we're getting - # about the scalar divide in the merge_ebms line: - # "percentage.append((new_high - new_low) / (old_high - old_low))" - X, y, names, _ = make_synthetic(classes=2, missing=True, output_type="str") with warnings.catch_warnings(): @@ -57,6 +65,7 @@ def test_merge_ebms(): max_bins=10, max_interaction_bins=5, interactions=[(8, 3, 0)], + **_fast_kwds, ) ebm1.fit(X_train, y_train) @@ -70,6 +79,7 @@ def test_merge_ebms(): max_bins=11, max_interaction_bins=4, interactions=[(8, 2), (7, 3), (1, 2)], + **_fast_kwds, ) ebm2.fit(X_train, y_train) @@ -83,11 +93,12 @@ def test_merge_ebms(): max_bins=12, max_interaction_bins=3, interactions=[(1, 2), (2, 8)], + **_fast_kwds, ) ebm3.fit(X_train, y_train) merged_ebm1 = merge_ebms([ebm1, ebm2, ebm3]) - # valid_ebm(merged_ebm1) + valid_ebm(merged_ebm1) global_exp = merged_ebm1.explain_global() local_exp = merged_ebm1.explain_local(X[:5, :], y[:5]) smoke_test_explanations(global_exp, local_exp, 6000) @@ -102,11 +113,12 @@ def test_merge_ebms(): max_bins=13, max_interaction_bins=8, interactions=2, + **_fast_kwds, ) ebm4.fit(X_train, y_train) merged_ebm2 = merge_ebms([merged_ebm1, ebm4]) - # valid_ebm(merged_ebm2) + valid_ebm(merged_ebm2) global_exp = merged_ebm2.explain_global() local_exp = merged_ebm2.explain_local(X[:5, :], y[:5]) smoke_test_explanations(global_exp, local_exp, 6000) @@ -121,11 +133,12 @@ def test_merge_ebms(): max_bins=14, max_interaction_bins=8, interactions=2, + **_fast_kwds, ) ebm5.fit(X_train, y_train) merged_ebm3 = merge_ebms([ebm5, merged_ebm2]) - # valid_ebm(merged_ebm3) + valid_ebm(merged_ebm3) global_exp = merged_ebm3.explain_global() local_exp = merged_ebm3.explain_local(X[:5, :], y[:5]) smoke_test_explanations(global_exp, local_exp, 6000) @@ -146,6 +159,7 @@ def test_merge_ebms_multiclass(): random_state=random_state, interactions=0, max_bins=10, + **_fast_kwds, ) ebm1.fit(X_train, y_train) @@ -157,6 +171,7 @@ def test_merge_ebms_multiclass(): random_state=random_state, interactions=0, max_bins=11, + **_fast_kwds, ) ebm2.fit(X_train, y_train) @@ -168,6 +183,7 @@ def test_merge_ebms_multiclass(): random_state=random_state, interactions=0, max_bins=12, + **_fast_kwds, ) ebm3.fit(X_train, y_train) @@ -182,7 +198,10 @@ def test_merge_ebms_multiclass(): X, y, test_size=0.10, random_state=random_state ) ebm4 = ExplainableBoostingClassifier( - random_state=random_state, interactions=0, max_bins=13 + random_state=random_state, + interactions=0, + max_bins=13, + # **_fast_kwds, ) ebm4.fit(X_train, y_train) @@ -197,7 +216,10 @@ def test_merge_ebms_multiclass(): X, y, test_size=0.50, random_state=random_state ) ebm5 = ExplainableBoostingClassifier( - random_state=random_state, interactions=0, max_bins=14 + random_state=random_state, + interactions=0, + max_bins=14, + # **_fast_kwds, ) ebm5.fit(X_train, y_train) @@ -206,3 +228,76 @@ def test_merge_ebms_multiclass(): global_exp = merged_ebm3.explain_global() local_exp = merged_ebm3.explain_local(X_te, y_te) smoke_test_explanations(global_exp, local_exp, 6000) + + +@pytest.mark.filterwarnings("ignore:Missing values detected.:UserWarning") +def test_unfitted(): + """To merge EBMs, all have to be fitted.""" + X, y, names, _ = make_synthetic(classes=2, missing=True, output_type="str") + TestEBM = partial( + ExplainableBoostingClassifier, + feature_names=names, + random_state=42, + **_fast_kwds, + ) + ebm1 = TestEBM() + ebm1.fit(X, y) + ebm2 = TestEBM() + # ebm2 is not fitted + with pytest.raises(Exception, match="All models must be 
fitted."): + merge_ebms([ebm1, ebm2]) + + +@pytest.mark.filterwarnings("ignore:Missing values detected.:UserWarning") +def test_merge_monotone(): + """Check merging of features with `monotone_constraints`.""" + X, y, names, _ = make_synthetic(classes=None, missing=True, output_type="str") + TestEBM = partial( + ExplainableBoostingRegressor, + feature_names=names, + random_state=42, + **_fast_kwds, + ) + # feature 3, 6 are truly monotonous increasing, 7 has no impact + ebm1 = TestEBM(monotone_constraints=[0, 0, 0, +1, 0, 0, +1, +1, 0, 0]) + ebm1.fit(X, y) + ebm2 = TestEBM(monotone_constraints=[0, 0, 0, +1, 0, 0, +0, -1, 0, 0]) + ebm2.fit(X, y) + merged_ebm = merge_ebms([ebm1, ebm2]) + assert merged_ebm.monotone_constraints == [0, 0, 0, +1, 0, 0, +0, +0, 0, 0] + merged_ebm = merge_ebms([ebm2, ebm2]) + assert merged_ebm.monotone_constraints == [0, 0, 0, +1, 0, 0, +0, -1, 0, 0] + ebm3 = TestEBM(monotone_constraints=None) + ebm3.fit(X, y) + merged_ebm = merge_ebms([ebm1, ebm2, ebm3]) + assert merged_ebm.monotone_constraints is None + + +@pytest.mark.filterwarnings("ignore:Missing values detected.:UserWarning") +def test_merge_exclude(): + """Check merging of features with `exclude`.""" + X, y, names, _ = make_synthetic(classes=2, missing=True, output_type="str") + TestEBM = partial( + ExplainableBoostingClassifier, + feature_names=names, + random_state=42, + **_fast_kwds, + ) + ebm1 = TestEBM(exclude=None) + ebm1.fit(X, y) + ebm2 = TestEBM(exclude=[0, 1, 2]) + ebm2.fit(X, y) + merged_ebm = merge_ebms([ebm1, ebm2]) + assert merged_ebm.exclude is None + ebm1 = TestEBM(exclude=[0, 2]) + ebm1.fit(X, y) + ebm2 = TestEBM(exclude=[0, 1, 2]) + ebm2.fit(X, y) + merged_ebm = merge_ebms([ebm1, ebm2]) + assert merged_ebm.exclude == [(0,), (2,)] + ebm1 = TestEBM(exclude="mains") + ebm1.fit(X, y) + ebm2 = TestEBM(exclude=[0, 1, 2]) + ebm2.fit(X, y) + merged_ebm = merge_ebms([ebm1, ebm2]) + assert merged_ebm.exclude == [(0,), (1,), (2,)]