Skip to content

Commit

Permalink
Add custom target class for gnn explainer
Browse files Browse the repository at this point in the history
  • Loading branch information
m30m committed Jan 13, 2021
1 parent 718f766 commit 2a7f0d5
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
7 changes: 4 additions & 3 deletions explain_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from captum.attr import Saliency, IntegratedGradients, LayerGradCam
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import GNNExplainer, MessagePassing
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx

from pgm_explainer import Node_Explainer
from gnn_explainer import TargetedGNNExplainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -244,8 +245,8 @@ def explain_occlusion_undirected(model, node_idx, x, edge_index, target, include


def explain_gnnexplainer(model, node_idx, x, edge_index, target, include_edges=None):
    """Explain ``node_idx`` with GNNExplainer, optimized for class ``target``.

    Uses :class:`TargetedGNNExplainer` so the edge/feature masks are trained
    against the supplied ``target`` class rather than the model's own
    prediction (which is what the stock ``GNNExplainer.explain_node`` does).

    Args:
        model: The GNN model being explained.
        node_idx (int): The node whose prediction is explained.
        x (Tensor): The node feature matrix.
        edge_index (LongTensor): The edge indices.
        target: Class index the explanation is optimized for. If ``None``,
            ``TargetedGNNExplainer`` falls back to the model's prediction.
        include_edges: Unused here; kept for signature parity with the other
            ``explain_*`` functions in this module.

    Returns:
        numpy.ndarray: Per-edge importance scores (sigmoid of the learned
        edge mask), moved to CPU.
    """
    explainer = TargetedGNNExplainer(model, epochs=200, log=False)
    node_feat_mask, edge_mask = explainer.explain_node_with_target(
        node_idx, x, edge_index, target_class=target)
    return edge_mask.cpu().numpy()


Expand Down
84 changes: 84 additions & 0 deletions gnn_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
from torch_geometric.nn import GNNExplainer
from tqdm import tqdm

EPS = 1e-15


class TargetedGNNExplainer(GNNExplainer):
    """GNNExplainer variant that can explain an arbitrary target class.

    The stock ``GNNExplainer.explain_node`` always optimizes the masks for
    the class the model itself predicts; ``explain_node_with_target`` lets
    the caller pick the class instead (falling back to the prediction when
    ``target_class`` is ``None``).

    NOTE(review): relies on internals of the parent ``GNNExplainer``
    (``__subgraph__``, ``__set_masks__``, ``__clear_masks__``, ``coeffs``,
    ``node_feat_mask``/``edge_mask`` attributes) — these are not visible
    here; verify against the installed torch_geometric version.
    """
    def __loss__(self, node_idx, log_logits, target_class):
        """Return the mask-regularized loss for ``target_class``.

        ``node_idx`` indexes a row of ``log_logits``; in practice it is the
        ``mapping`` of the explained node into the k-hop subgraph (see
        ``explain_node_with_target``).
        """
        # Negative log-likelihood of the *requested* class — this is the only
        # line that differs in spirit from the parent's loss, which uses the
        # predicted class.
        loss = -log_logits[node_idx, target_class]

        # Edge-mask regularization: sparsity (sum of sigmoided mask) plus
        # binary entropy pushing entries toward 0/1, weighted by the parent
        # class' ``coeffs``.
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        # Same two regularizers for the node-feature mask.
        m = self.node_feat_mask.sigmoid()
        loss = loss + self.coeffs['node_feat_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def explain_node_with_target(self, node_idx, x, edge_index, target_class, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for node
        :attr:`node_idx`, optimized for class :attr:`target_class`.
        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            target_class (int, optional): The class to explain. If
                :obj:`None`, the model's own prediction for
                :attr:`node_idx` is used instead.
            **kwargs (optional): Additional arguments passed to the GNN module.
        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """

        self.model.eval()
        self.__clear_masks__()

        # Remember the full-graph edge count so the returned edge mask can be
        # scattered back to full-graph size at the end.
        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx`.
        x, edge_index, mapping, hard_edge_mask, kwargs = self.__subgraph__(
            node_idx, x, edge_index, **kwargs)

        # Get the initial prediction.
        # Fallback: with no explicit target, explain the model's own
        # prediction for the node (matches stock GNNExplainer behavior).
        if target_class is None:
            with torch.no_grad():
                log_logits = self.model(x=x, edge_index=edge_index, **kwargs)
                pred_label = log_logits.argmax(dim=-1)
                target_class = pred_label[mapping].item()

        self.__set_masks__(x, edge_index)
        self.to(x.device)

        # Only the two masks are trained; the model's own weights stay frozen.
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.epochs)
            pbar.set_description(f'Explain node {node_idx}')

        for epoch in range(1, self.epochs + 1):
            optimizer.zero_grad()
            # Apply the (sigmoided) feature mask to the inputs; the edge mask
            # is applied implicitly inside message passing via __set_masks__.
            h = x * self.node_feat_mask.view(1, -1).sigmoid()
            log_logits = self.model(x=h, edge_index=edge_index, **kwargs)
            # ``mapping`` is the explained node's index within the subgraph.
            loss = self.__loss__(mapping, log_logits, target_class)
            loss.backward()
            optimizer.step()

            if self.log:  # pragma: no cover
                pbar.update(1)

        if self.log:  # pragma: no cover
            pbar.close()

        node_feat_mask = self.node_feat_mask.detach().sigmoid()
        # Scatter the subgraph edge mask back to full-graph positions; edges
        # outside the k-hop subgraph get importance 0.
        edge_mask = self.edge_mask.new_zeros(num_edges)
        edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid()

        self.__clear_masks__()

        return node_feat_mask, edge_mask

0 comments on commit 2a7f0d5

Please sign in to comment.