Skip to content

Commit

Permalink
modify data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
ChandlerBang committed Feb 25, 2020
1 parent 844bccc commit f0913d5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deeprobust/graph/data/attacked_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class PtbDataset:
This class manages pre-attacked/perturbed adjacency matrix on different datasets
'''

def __init__(self, root, name, attack_method='mettack', require_lcc=True, transform=None):
def __init__(self, root, name, attack_method='mettack', transform=None):
assert attack_method == 'mettack', \
'Currently the database only stores graphs perturbed by 5% mettack'

Expand Down
32 changes: 29 additions & 3 deletions deeprobust/graph/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class Dataset():

def __init__(self, root, name, setting='nettack', seed=None):
def __init__(self, root, name, setting='nettack', seed=None, require_mask=False):
self.name = name.lower()
self.setting = setting.lower()

Expand All @@ -21,10 +21,13 @@ def __init__(self, root, name, setting='nettack', seed=None):
self.root = osp.expanduser(osp.normpath(root))
self.data_filename = osp.join(root, self.name)
self.data_filename += '.npz'
self.require_mask = require_mask

self.require_lcc = True if setting == 'nettack' else False
self.adj, self.features, self.labels = self.load_data()
self.idx_train, self.idx_val, self.idx_test = self.get_train_val_test()
if self.require_mask:
self.get_mask()

def get_train_val_test(self):

Expand Down Expand Up @@ -105,6 +108,29 @@ def largest_connected_components(self, adj, n_components=1):
return nodes_to_keep

def __repr__(self):
return '{}()'.format(self.name)

return f'{self.name}(adj_shape={self.adj.shape}, feature_shape={self.features.shape})'

def get_mask(self):
    """Build boolean split masks and masked one-hot label matrices.

    Reads ``self.idx_train``, ``self.idx_val``, ``self.idx_test`` and
    ``self.labels`` and stores the results on the instance:

    * ``train_mask`` / ``val_mask`` / ``test_mask``: boolean vectors with
      one entry per row of the one-hot label matrix, ``True`` at the
      indices of the corresponding split.
    * ``y_train`` / ``y_val`` / ``y_test``: arrays with the same shape as
      the one-hot label matrix, holding the one-hot rows at the indices
      of the corresponding split and zeros everywhere else.
    """
    labels = self.onehot(self.labels)
    n_rows = labels.shape[0]

    def build_mask(idx):
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        mask = np.zeros(n_rows, dtype=bool)
        mask[idx] = True
        return mask

    def build_y(idx):
        # Copy only the rows belonging to this split; the rest stay zero.
        mx = np.zeros(labels.shape)
        mx[idx] = labels[idx]
        return mx

    self.train_mask = build_mask(self.idx_train)
    self.val_mask = build_mask(self.idx_val)
    self.test_mask = build_mask(self.idx_test)
    self.y_train = build_y(self.idx_train)
    self.y_val = build_y(self.idx_val)
    self.y_test = build_y(self.idx_test)

def onehot(self, labels):
    """Return the one-hot encoding of an integer label array.

    The number of columns is ``labels.max() + 1``; row ``i`` of the result
    is the row of the identity matrix selected by ``labels[i]``.
    """
    n_classes = labels.max() + 1
    return np.identity(n_classes)[labels]

4 changes: 2 additions & 2 deletions deeprobust/graph/defense/gcn_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
super(GCNSVD, self).__init__(nfeat, nhid, nclass, dropout, lr, weight_decay, with_relu, with_bias, device=device)
self.device = device

def fit(self, features, adj, labels, idx_train, k=50, train_iters=200, initialize=True, verbose=True):
def fit(self, features, adj, labels, idx_train, idx_val=None, k=50, train_iters=200, initialize=True, verbose=True):

modified_adj = self.truncatedSVD(adj, k=k)
# modified_adj_tensor = utils.sparse_mx_to_torch_sparse_tensor(self.modified_adj)
Expand All @@ -27,7 +27,7 @@ def fit(self, features, adj, labels, idx_train, k=50, train_iters=200, initializ
self.modified_adj = modified_adj
self.features = features
self.labels = labels
super().fit(features, modified_adj, labels, idx_train, train_iters=train_iters, initialize=initialize, verbose=verbose)
super().fit(features, modified_adj, labels, idx_train, idx_val, train_iters=train_iters, initialize=initialize, verbose=verbose)

def truncatedSVD(self, data, k=50):
print(f'=== GCN-SVD: rank={k} ===')
Expand Down

0 comments on commit f0913d5

Please sign in to comment.