Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: TensorFlow explainers have the same signature as the rest #415

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions dice_ml/explainer_interfaces/dice_tensorflow1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import tensorflow as tf

from dice_ml import diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


Expand Down Expand Up @@ -61,20 +60,22 @@ def __init__(self, data_interface, model_interface):
self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights
self.optimizer_weights = [] # optimizer

def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
def _generate_counterfactuals(self, query_instance, total_CFs,
desired_class="opposite", desired_range=None,
proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
"""Generates diverse counterfactual explanations

:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe.
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
outcome class of query_instance for binary classification.
:param desired_range: Not supported currently.
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the
query_instance.
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
Expand Down Expand Up @@ -159,16 +160,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random,
stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm)

counterfactual_explanations = exp.CounterfactualExamples(
return exp.CounterfactualExamples(
data_interface=self.data_interface,
final_cfs_df=final_cfs_df,
test_instance_df=test_instance_df,
final_cfs_df_sparse=final_cfs_df_sparse,
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])

def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
"""Intializes TF variables required for CF generation."""

Expand Down
23 changes: 11 additions & 12 deletions dice_ml/explainer_interfaces/dice_tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import tensorflow as tf

from dice_ml import diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


Expand Down Expand Up @@ -49,20 +48,22 @@ def __init__(self, data_interface, model_interface):
self.hyperparameters = [1, 1, 1] # proximity_weight, diversity_weight, categorical_penalty
self.optimizer_weights = [] # optimizer, learning_rate

def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
def _generate_counterfactuals(self, query_instance, total_CFs,
desired_class="opposite", desired_range=None,
proximity_weight=0.5,
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
"""Generates diverse counterfactual explanations

:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
outcome class of query_instance for binary classification.
:param desired_range: Not supported currently.
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to
the query_instance.
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
Expand Down Expand Up @@ -136,16 +137,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
init_near_query_instance, tie_random, stopping_threshold,
posthoc_sparsity_param, posthoc_sparsity_algorithm, limit_steps_ls)

counterfactual_explanations = exp.CounterfactualExamples(
return exp.CounterfactualExamples(
data_interface=self.data_interface,
final_cfs_df=final_cfs_df,
test_instance_df=test_instance_df,
final_cfs_df_sparse=final_cfs_df_sparse,
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])

def predict_fn(self, input_instance):
"""prediction function"""
temp_preds = self.model.get_output(input_instance).numpy()
Expand Down
Loading