pgm_explainer.py
import numpy as np
import pandas as pd
import torch
from pgmpy.estimators.CITests import chi_square
from scipy.special import softmax
from torch_geometric.utils import k_hop_subgraph
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Node_Explainer:
    """PGM-Explainer for node classification.

    Randomly perturbs the features of nodes in the k-hop neighborhood of a
    target node, records whether the model's prediction changes, and scores
    each neighbor's influence with a chi-square independence test.
    """

    def __init__(
        self,
        model,
        edge_index,
        X,
        num_layers,
        mode=0,
        print_result=1
    ):
        self.model = model
        self.model.eval()
        self.edge_index = edge_index
        self.X = X
        self.num_layers = num_layers
        self.mode = mode
        self.print_result = print_result
    def perturb_features_on_node(self, feature_matrix, node_idx, random=0, mode=0):
        # Perturb the feature row of `node_idx` in place and return the matrix.
        # random: 0 keeps the node's original features, 1 perturbs them.
        # mode:   0 replaces the row with random 0/1 values,
        #         1 rescales the original row by uniform factors in [0, 2).
        X_perturb = feature_matrix
        if mode == 0:
            if random == 0:
                perturb_array = X_perturb[node_idx]
            elif random == 1:
                perturb_array = np.random.randint(2, size=X_perturb[node_idx].shape[0])
            X_perturb[node_idx] = perturb_array
        elif mode == 1:
            if random == 0:
                perturb_array = X_perturb[node_idx]
            elif random == 1:
                perturb_array = np.multiply(
                    X_perturb[node_idx],
                    np.random.uniform(low=0.0, high=2.0, size=X_perturb[node_idx].shape[0]))
            X_perturb[node_idx] = perturb_array
        return X_perturb
    def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=0.05, pred_threshold=0.1):
        # Nodes within `num_layers` hops can influence the prediction at `node_idx`.
        neighbors, _, _, _ = k_hop_subgraph(node_idx, self.num_layers, self.edge_index)
        neighbors = neighbors.cpu().detach().numpy()
        if node_idx not in neighbors:
            neighbors = np.append(neighbors, node_idx)

        # Reference prediction on the unperturbed features.
        pred_torch = self.model(self.X, self.edge_index).cpu()
        soft_pred = np.asarray(
            [softmax(np.asarray(pred_torch[node_].data)) for node_ in range(self.X.shape[0])])
        pred_node = np.asarray(pred_torch[node_idx].data)
        label_node = np.argmax(pred_node)
        soft_pred_node = softmax(pred_node)

        Samples = []
        Pred_Samples = []
        for iteration in range(num_samples):
            # Copy the features so perturbations never modify self.X in place.
            X_perturb = self.X.cpu().detach().numpy().copy()
            sample = []
            for node in neighbors:
                # Randomly decide whether to perturb this neighbor in this sample.
                seed = np.random.randint(2)
                if seed == 1:
                    latent = 1
                    X_perturb = self.perturb_features_on_node(X_perturb, node, random=seed)
                else:
                    latent = 0
                sample.append(latent)

            # Re-run the model on the perturbed features.
            X_perturb_torch = torch.tensor(X_perturb, dtype=torch.float).to(device)
            pred_perturb_torch = self.model(X_perturb_torch, self.edge_index).cpu()
            soft_pred_perturb = np.asarray(
                [softmax(np.asarray(pred_perturb_torch[node_].data)) for node_ in range(self.X.shape[0])])

            # Record, per node, whether the target-class probability dropped noticeably.
            sample_bool = []
            for node in neighbors:
                if (soft_pred_perturb[node, target] + pred_threshold) < soft_pred[node, target]:
                    sample_bool.append(1)
                else:
                    sample_bool.append(0)

            Samples.append(sample)
            Pred_Samples.append(sample_bool)
        Samples = np.asarray(Samples)
        Pred_Samples = np.asarray(Pred_Samples)

        # Encode each (perturbed, prediction-dropped) pair as one categorical value.
        Combine_Samples = np.zeros_like(Samples)
        for s in range(Samples.shape[0]):
            Combine_Samples[s] = np.asarray(
                [Samples[s, i] * 10 + Pred_Samples[s, i] + 1 for i in range(Samples.shape[1])])

        data = pd.DataFrame(Combine_Samples)
        data = data.rename(columns={0: "A", 1: "B"})  # Trick to use chi_square test on first two data columns
        ind_ori_to_sub = dict(zip(neighbors, list(data.columns)))

        # Chi-square test of dependence between each neighbor and the target node.
        p_values = []
        for node in neighbors:
            chi2, p = chi_square(ind_ori_to_sub[node], ind_ori_to_sub[node_idx], [], data)
            p_values.append(p)

        # Map each neighbor's node id to its p-value (smaller = stronger influence).
        pgm_stats = dict(zip(neighbors, p_values))
        return pgm_stats
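

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example assuming torch_geometric is installed. The two-layer GCN and the
# random graph below are illustrative placeholders; any model callable as
# model(X, edge_index) that returns per-node logits should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    class GCN(torch.nn.Module):
        def __init__(self, in_dim, hidden_dim, num_classes):
            super().__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, num_classes)

        def forward(self, x, edge_index):
            return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

    # Small random graph: 20 nodes, 8 binary features, 3 classes.
    num_nodes, num_feats, num_classes = 20, 8, 3
    X = torch.randint(0, 2, (num_nodes, num_feats)).float().to(device)
    edge_index = torch.randint(0, num_nodes, (2, 60)).to(device)
    model = GCN(num_feats, 16, num_classes).to(device)

    node_idx = 0
    with torch.no_grad():
        target = model(X, edge_index)[node_idx].argmax().item()

    explainer = Node_Explainer(model, edge_index, X, num_layers=2)
    pgm_stats = explainer.explain(node_idx, target, num_samples=50)
    # pgm_stats maps each neighbor's node id to the chi-square p-value;
    # smaller p-values suggest more influential neighbors.
    print(pgm_stats)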