Skip to content

Commit

Permalink
use sparse adj to compute affinity matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhihengli-UR committed Sep 7, 2020
1 parent 92e588c commit b59efc1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def eval(test_dataloader, model, args):
for num_cg_bead in iter_num_cg_beads:
# try:
if args.inference_method == 'dsgpm':
hard_assign, _ = graph_cuts(fg_embed, dense_adj, num_cg_bead, args.bandwidth, device=args.device_for_affinity_matrix)
hard_assign, _ = graph_cuts(fg_embed, data.edge_index, num_cg_bead, args.bandwidth, device=args.device_for_affinity_matrix)
hard_assign = enforce_connectivity(hard_assign, edge_index_cpu)
elif args.inference_method == 'spec_cluster':
hard_assign = graph_cuts_with_adj(dense_adj, num_cg_bead)
Expand Down
14 changes: 8 additions & 6 deletions model/graph_cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from sklearn.cluster import spectral_clustering


def graph_cuts(fg_embed, adj, num_cg, bandwidth=1.0, kernel='rbf', device=torch.device(0)):
affinity = compute_affinity(fg_embed, adj, bandwidth, kernel, device)
def graph_cuts(fg_embed, edge_index, num_cg, bandwidth=1.0, kernel='rbf', device=torch.device(0)):
affinity = compute_affinity(fg_embed, edge_index, bandwidth, kernel, device)

pred_cg_idx = spectral_clustering(affinity.cpu().numpy(), n_clusters=num_cg, assign_labels='discretize')
return pred_cg_idx, affinity
Expand All @@ -21,15 +21,17 @@ def graph_cuts_with_adj(adj, num_cg):
return pred_cg_idx


def compute_affinity(fg_embed, adj, bandwidth=1.0, kernel='rbf', device=torch.device(0)):
def compute_affinity(fg_embed, edge_index, bandwidth=1.0, kernel='rbf', device=torch.device(0)):
if kernel == 'rbf':
n, d = fg_embed.shape
num_nodes = fg_embed.shape[0]
fg_embed = fg_embed.to(device)
pairwise_dist = torch.norm(fg_embed.reshape(n, 1, d) - fg_embed.reshape(1, n, d), dim=2).to(torch.device(0))
pairwise_dist = torch.norm(fg_embed[edge_index[0]] - fg_embed[edge_index[1]], dim=1).to(torch.device(0))
# pairwise_dist = torch.norm(fg_embed.reshape(n, 1, d) - fg_embed.reshape(1, n, d), dim=2).to(torch.device(0))

pairwise_dist = pairwise_dist ** 2
affinity = torch.exp(-pairwise_dist / (2 * bandwidth ** 2))
affinity = affinity * adj
# affinity = affinity * adj
affinity = torch.sparse.LongTensor(edge_index, affinity, (num_nodes, num_nodes)).to_dense()
elif kernel == 'linear':
affinity = F.relu(fg_embed @ fg_embed.t())
else:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def eval(fold, epoch, test_dataloader, model, args):
best_precision, best_recall, best_f_score = -1, 0, 0, 0

for anno_idx, gt_hard_assign in enumerate(gt_hard_assigns):
hard_assign, _ = graph_cuts(fg_embed, dense_adj, max_num_cg_beads[anno_idx], args.bandwidth)
hard_assign, _ = graph_cuts(fg_embed, data.edge_index, max_num_cg_beads[anno_idx], args.bandwidth)
try:
hard_assign = enforce_connectivity(hard_assign, edge_index_cpu)
except:
Expand Down

0 comments on commit b59efc1

Please sign in to comment.