diff --git a/layers.py b/layers.py
index ece9e3345..58dc02a13 100644
--- a/layers.py
+++ b/layers.py
@@ -11,6 +11,48 @@ class GraphAttentionLayer(nn.Module):
 
     def __init__(self, in_features, out_features, dropout, alpha, concat=True):
         super(GraphAttentionLayer, self).__init__()
+        self.dropout = dropout
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+        self.concat = concat
+
+        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        nn.init.xavier_uniform_(self.W.data, gain=1.414)
+        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
+        nn.init.xavier_uniform_(self.a.data, gain=1.414)
+
+        self.leakyrelu = nn.LeakyReLU(self.alpha)
+
+    def forward(self, input, adj):
+        h = torch.mm(input, self.W)
+        N = h.size()[0]
+
+        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
+        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
+
+        zero_vec = -9e15*torch.ones_like(e)
+        attention = torch.where(adj > 0, e, zero_vec)
+        attention = F.softmax(attention, dim=1)
+        attention = F.dropout(attention, self.dropout, training=self.training)
+        h_prime = torch.matmul(attention, h)
+
+        if self.concat:
+            return F.elu(h_prime)
+        else:
+            return h_prime
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
+
+
+class SpGraphAttentionLayer(nn.Module):
+    """
+    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
+    """
+
+    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
+        super(SpGraphAttentionLayer, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.alpha = alpha
diff --git a/models.py b/models.py
index 3cdc05d5c..59ee34a0f 100644
--- a/models.py
+++ b/models.py
@@ -1,27 +1,48 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from layers import GraphAttentionLayer
+from layers import GraphAttentionLayer, SpGraphAttentionLayer
 
 
 class GAT(nn.Module):
     def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
+        """Dense version of GAT."""
         super(GAT, self).__init__()
         self.dropout = dropout
 
-        self.attentions = [GraphAttentionLayer(nfeat,
-                                               nhid,
-                                               dropout=dropout,
-                                               alpha=alpha,
-                                               concat=True) for _ in range(nheads)]
+        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
         for i, attention in enumerate(self.attentions):
             self.add_module('attention_{}'.format(i), attention)
 
-        self.out_att = GraphAttentionLayer(nhid * nheads,
-                                           nclass,
-                                           dropout=dropout,
-                                           alpha=alpha,
-                                           concat=False)
+        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
+
+    def forward(self, x, adj):
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = F.elu(self.out_att(x, adj))
+        return F.log_softmax(x, dim=1)
+
+
+class SpGAT(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
+        """Sparse version of GAT."""
+        super(SpGAT, self).__init__()
+        self.dropout = dropout
+
+        self.attentions = [SpGraphAttentionLayer(nfeat,
+                                                 nhid,
+                                                 dropout=dropout,
+                                                 alpha=alpha,
+                                                 concat=True) for _ in range(nheads)]
+        for i, attention in enumerate(self.attentions):
+            self.add_module('attention_{}'.format(i), attention)
+
+        self.out_att = SpGraphAttentionLayer(nhid * nheads,
+                                             nclass,
+                                             dropout=dropout,
+                                             alpha=alpha,
+                                             concat=False)
 
     def forward(self, x, adj):
         x = F.dropout(x, self.dropout, training=self.training)
diff --git a/train.py b/train.py
index e3e5c4311..88ffbcda7 100644
--- a/train.py
+++ b/train.py
@@ -14,12 +14,13 @@
 from torch.autograd import Variable
 
 from utils import load_data, accuracy
-from models import GAT
+from models import GAT, SpGAT
 
 # Training settings
 parser = argparse.ArgumentParser()
 parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
 parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.')
+parser.add_argument('--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
 parser.add_argument('--seed', type=int, default=72, help='Random seed.')
 parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
 parser.add_argument('--lr', type=float, default=0.005, help='Initial learning rate.')
@@ -43,12 +44,20 @@
 adj, features, labels, idx_train, idx_val, idx_test = load_data()
 
 # Model and optimizer
-model = GAT(nfeat=features.shape[1],
-            nhid=args.hidden,
-            nclass=int(labels.max()) + 1,
-            dropout=args.dropout,
-            nheads=args.nb_heads,
-            alpha=args.alpha)
+if args.sparse:
+    model = SpGAT(nfeat=features.shape[1],
+                  nhid=args.hidden,
+                  nclass=int(labels.max()) + 1,
+                  dropout=args.dropout,
+                  nheads=args.nb_heads,
+                  alpha=args.alpha)
+else:
+    model = GAT(nfeat=features.shape[1],
+                nhid=args.hidden,
+                nclass=int(labels.max()) + 1,
+                dropout=args.dropout,
+                nheads=args.nb_heads,
+                alpha=args.alpha)
 optimizer = optim.Adam(model.parameters(),
                        lr=args.lr,
                        weight_decay=args.weight_decay)
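Usage note (outside the diff): a minimal, hypothetical smoke test for the classes this change touches, assuming the module layout above. The hyperparameter values and graph sizes are illustrative only, not taken from the diff, and the real adjacency matrix would come from load_data() in train.py.

    # Hypothetical smoke test -- values and shapes are illustrative, not part of the diff.
    import torch
    from models import GAT, SpGAT

    N, nfeat, nclass = 6, 10, 3            # tiny toy graph (assumed sizes)
    x = torch.randn(N, nfeat)              # node features
    adj = torch.eye(N)                     # placeholder adjacency; train.py loads the real one via load_data()

    dense = GAT(nfeat=nfeat, nhid=8, nclass=nclass, dropout=0.6, alpha=0.2, nheads=8)
    out = dense(x, adj)                    # log-probabilities, shape [N, nclass]

    # The new --sparse flag makes train.py build the SpGAT variant with the same signature:
    sparse = SpGAT(nfeat=nfeat, nhid=8, nclass=nclass, dropout=0.6, alpha=0.2, nheads=8)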