Skip to content

Commit

Permalink
fix: TensorFlow explainers have the same signature as the rest
Browse files Browse the repository at this point in the history
Right now the TensorFlow gradient explainers do not implement
the `_generate_counterfactuals` method from the `BaseExplainer`
class. This means that:
	1) You cannot instantiate object of class
	`DiceTensorFlow(1/2)` without replacing the `__class__`
	of another method, because it is not a valid child
	of `ExplainerBase`.
	2) By overriding the parent `generate_counterfactuals`
	method, the two classes bypass any validation steps
	that would normally be carried out by the base class
	(e.g. checking that the number of CF queries is non-negative).

Signed-off-by: Asen Dotsinski <[email protected]>
  • Loading branch information
asendotsinski committed Oct 29, 2023
1 parent a7b62c4 commit 6a077eb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
22 changes: 11 additions & 11 deletions dice_ml/explainer_interfaces/dice_tensorflow1.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,22 @@ def __init__(self, data_interface, model_interface):
self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights
self.optimizer_weights = [] # optimizer

def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
def _generate_counterfactuals(self, query_instance, total_CFs,
desired_class="opposite", desired_range=None,
proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
"""Generates diverse counterfactual explanations
:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe.
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
outcome class of query_instance for binary classification.
:param desired_range: Not supported currently.
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the
query_instance.
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
Expand Down Expand Up @@ -159,16 +161,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random,
stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm)

counterfactual_explanations = exp.CounterfactualExamples(
return exp.CounterfactualExamples(
data_interface=self.data_interface,
final_cfs_df=final_cfs_df,
test_instance_df=test_instance_df,
final_cfs_df_sparse=final_cfs_df_sparse,
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])

def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
"""Intializes TF variables required for CF generation."""

Expand Down
21 changes: 11 additions & 10 deletions dice_ml/explainer_interfaces/dice_tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,22 @@ def __init__(self, data_interface, model_interface):
self.hyperparameters = [1, 1, 1] # proximity_weight, diversity_weight, categorical_penalty
self.optimizer_weights = [] # optimizer, learning_rate

def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
def _generate_counterfactuals(self, query_instance, total_CFs,
desired_class="opposite", desired_range=None,
proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
"""Generates diverse counterfactual explanations
:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
outcome class of query_instance for binary classification.
:param desired_range: Not supported currently.
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to
the query_instance.
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
Expand Down Expand Up @@ -136,15 +138,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
init_near_query_instance, tie_random, stopping_threshold,
posthoc_sparsity_param, posthoc_sparsity_algorithm, limit_steps_ls)

counterfactual_explanations = exp.CounterfactualExamples(
return exp.CounterfactualExamples(
data_interface=self.data_interface,
final_cfs_df=final_cfs_df,
test_instance_df=test_instance_df,
final_cfs_df_sparse=final_cfs_df_sparse,
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])

def predict_fn(self, input_instance):
"""prediction function"""
Expand Down

0 comments on commit 6a077eb

Please sign in to comment.