Skip to content

Commit

Permalink
Change Edge Gradients to be not absolute
Browse files Browse the repository at this point in the history
  • Loading branch information
m30m committed Feb 1, 2021
1 parent c8bbc01 commit ee81463
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion explain_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def explain_sa(model, node_idx, x, edge_index, target, include_edges=None):
saliency = Saliency(model_forward)
input_mask = torch.ones(edge_index.shape[1]).requires_grad_(True).to(device)
saliency_mask = saliency.attribute(input_mask, target=target,
additional_forward_args=(model, node_idx, x, edge_index))
additional_forward_args=(model, node_idx, x, edge_index), abs=False)

edge_mask = saliency_mask.cpu().numpy()
return edge_mask
Expand Down

0 comments on commit ee81463

Please sign in to comment.