Skip to content

Commit

Permalink
Make PGMExplainer method GPU compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohammadamin Khashkhashi committed Dec 29, 2020
1 parent 11a2032 commit 1450272
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion benchmarks/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def create_dataset(self):
infection_sources = defaultdict(list)
for idx, (u, v) in enumerate(zip(*data.edge_index.cpu().numpy())):
if features[u, -1]:
infection_sources[v].append(idx)
infection_sources[int(v)].append(idx)
explanations = []
for node, sources in infection_sources.items():
if node in infected_nodes: # no edge explanation for these nodes
Expand Down
11 changes: 6 additions & 5 deletions pgm_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.special import softmax
from torch_geometric.utils import k_hop_subgraph

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Node_Explainer:
def __init__(
Expand Down Expand Up @@ -47,12 +48,12 @@ def perturb_features_on_node(self, feature_matrix, node_idx, random=0, mode=0):

def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=0.05, pred_threshold=0.1):
neighbors, _, _, _ = k_hop_subgraph(node_idx, self.num_layers, self.edge_index)
neighbors = neighbors.detach().numpy()
neighbors = neighbors.cpu().detach().numpy()

if (node_idx not in neighbors):
neighbors = np.append(neighbors, node_idx)

pred_torch = self.model(self.X, self.edge_index)
pred_torch = self.model(self.X, self.edge_index).cpu()
soft_pred = np.asarray([softmax(np.asarray(pred_torch[node_].data)) for node_ in range(self.X.shape[0])])

pred_node = np.asarray(pred_torch[node_idx].data)
Expand All @@ -64,7 +65,7 @@ def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=

for iteration in range(num_samples):

X_perturb = self.X.detach().numpy()
X_perturb = self.X.cpu().detach().numpy()
sample = []
for node in neighbors:
seed = np.random.randint(2)
Expand All @@ -75,8 +76,8 @@ def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=
latent = 0
sample.append(latent)

X_perturb_torch = torch.tensor(X_perturb, dtype=torch.float)
pred_perturb_torch = self.model(X_perturb_torch, self.edge_index)
X_perturb_torch = torch.tensor(X_perturb, dtype=torch.float).to(device)
pred_perturb_torch = self.model(X_perturb_torch, self.edge_index).cpu()
soft_pred_perturb = np.asarray(
[softmax(np.asarray(pred_perturb_torch[node_].data)) for node_ in range(self.X.shape[0])])

Expand Down

0 comments on commit 1450272

Please sign in to comment.