Skip to content

Commit

Permalink
Extend categorical value check to numeric values (#212)
Browse files Browse the repository at this point in the history
* Extend categorical value check to numeric values

Signed-off-by: gaugup <[email protected]>

* Add unit tests

Signed-off-by: gaugup <[email protected]>
  • Loading branch information
gaugup authored Aug 20, 2021
1 parent 1f1ea41 commit 3122320
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
3 changes: 2 additions & 1 deletion dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query
raise ValueError("Feature", feature, "not present in training data!")

for feature in self.data_interface.categorical_feature_names:
if query_instance[feature].values[0] not in feature_ranges_orig[feature]:
if query_instance[feature].values[0] not in feature_ranges_orig[feature] and \
str(query_instance[feature].values[0]) not in feature_ranges_orig[feature]:
raise ValueError("Feature", feature, "has a value outside the dataset.")

if feature not in features_to_vary and permitted_range is not None:
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from collections import OrderedDict
import pandas as pd
import pytest
from sklearn.datasets import load_iris, load_boston
from sklearn.model_selection import train_test_split

import dice_ml
from dice_ml.utils import helpers

Expand Down Expand Up @@ -194,4 +195,6 @@ def create_boston_data():
x_train, x_test, y_train, y_test = train_test_split(
boston.data, boston.target,
test_size=0.2, random_state=7)
return x_train, x_test, y_train, y_test, boston.feature_names
x_train = pd.DataFrame(data=x_train, columns=boston.feature_names)
x_test = pd.DataFrame(data=x_test, columns=boston.feature_names)
return x_train, x_test, y_train, y_test, boston.feature_names.tolist()
35 changes: 31 additions & 4 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
import pytest
from sklearn.ensemble import RandomForestRegressor

import dice_ml
from dice_ml.utils.exception import UserConfigValidationException
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase

Expand Down Expand Up @@ -157,16 +159,41 @@ def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sam

class TestExplainerBaseRegression:

@pytest.mark.parametrize("desired_class, regression_exp_object",
[(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
@pytest.mark.parametrize("desired_range, regression_exp_object",
[([10, 100], 'random'), ([10, 100], 'genetic'), ([10, 100], 'kdtree')],
indirect=['regression_exp_object'])
def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1):
def test_zero_totalcfs(self, desired_range, regression_exp_object, sample_custom_query_1):
exp = regression_exp_object # explainer object
with pytest.raises(UserConfigValidationException):
exp.generate_counterfactuals(
query_instances=[sample_custom_query_1],
total_CFs=0,
desired_class=desired_class)
desired_range=desired_range)

@pytest.mark.parametrize("desired_range, method",
[([10, 100], 'random')])
def test_numeric_categories(self, desired_range, method, create_boston_data):
x_train, x_test, y_train, y_test, feature_names = \
create_boston_data

rfc = RandomForestRegressor(n_estimators=10, max_depth=4,
random_state=777)
model = rfc.fit(x_train, y_train)

dataset_train = x_train.copy()
dataset_train['Outcome'] = y_train
feature_names.remove('CHAS')

d = dice_ml.Data(dataframe=dataset_train, continuous_features=feature_names, outcome_name='Outcome')
m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor')
exp = dice_ml.Dice(d, m, method=method)

cf_explanation = exp.generate_counterfactuals(
query_instances=x_test.iloc[0:1],
total_CFs=10,
desired_range=desired_range)

assert cf_explanation is not None


class TestExplainerBase:
Expand Down

0 comments on commit 3122320

Please sign in to comment.