forked from layumi/Person_reID_baseline_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
118 lines (88 loc) · 3.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
import pickle
import numpy as np
import torch
def load_pickle(pickle_path):
    """Read one pickled object from ``pickle_path`` and return it.

    The file is opened in binary mode and closed again before returning.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    that come from a trusted source.
    """
    with open(pickle_path, 'rb') as handle:
        return pickle.Unpickler(handle).load()
def save_pickle(pickle_path, data):
    """Serialise ``data`` to ``pickle_path`` with the highest pickle protocol.

    An existing file at ``pickle_path`` is overwritten. Nothing is returned.
    """
    with open(pickle_path, 'wb') as sink:
        writer = pickle.Pickler(sink, protocol=pickle.HIGHEST_PROTOCOL)
        writer.dump(data)
def pairwise_squared_distance(x):
    '''
    Batched squared Euclidean distance between every pair of points.

    x : (n_samples, n_points, dims)
    return : (n_samples, n_points, n_points)

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b, so tiny
    negative values may appear from floating-point rounding.
    '''
    x_t = x.transpose(-1, -2)
    # Squared norm of each point, kept as a trailing column for broadcasting.
    sq_norms = torch.sum(x * x, dim=-1, keepdim=True)
    norm_sum = sq_norms + sq_norms.transpose(-1, -2)
    # (2 * x) @ x^T, the same evaluation order as ``2 * x @ x^T``.
    cross = torch.matmul(2 * x, x_t)
    return norm_sum - cross
def pairwise_distance(x, y):
    """Return the squared Euclidean distance between every row of x and y.

    Both inputs are flattened to 2-D, keeping their first dimension, so
    ``x`` becomes ``(m, d)`` and ``y`` becomes ``(n, d)``.

    Returns a new ``(m, n)`` tensor whose entry ``[i, j]`` is
    ``||x[i]||^2 + ||y[j]||^2 - 2 * x[i].y[j]``. No square root is taken and
    the result is not clamped, so rounding can give tiny negative values.
    """
    m, n = x.size(0), y.size(0)
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x = x.reshape(m, -1)
    y = y.reshape(n, -1)
    x_sq = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n)
    y_sq = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # The addition materialises a fresh (m, n) tensor, so the in-place
    # update below never writes into the expanded (aliased) views.
    dist = x_sq + y_sq
    # dist = 1 * dist + (-2) * (x @ y^T). The positional form
    # addmm_(beta, alpha, mat1, mat2) is no longer accepted by PyTorch;
    # beta and alpha have to be passed by keyword.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist
def cosine_similarity(x, y):
    """Return the ``(m, n)`` matrix of dot products between rows of x and y.

    Both inputs are flattened to 2-D, keeping their first dimension, before
    the matrix product. No normalisation is applied here, so the result is a
    true cosine similarity only when the rows are already L2-normalised.
    """
    rows_x = x.view(x.size(0), -1)
    rows_y = y.view(y.size(0), -1)
    return torch.mm(rows_x, rows_y.t())
def evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam):
    """Print Rank@1 / Rank@5 / Rank@10 and mAP for a set of ranking lists.

    ``indices[i]`` is the ranking of gallery indices for query ``i``;
    ``query_label[i]`` / ``query_cam[i]`` are that query's identity and
    camera. ``gallery_label`` / ``gallery_cam`` describe the gallery.

    Queries for which ``evaluate`` signals "no valid match" (first CMC entry
    equal to -1) are skipped, but the averages still divide by the total
    number of queries. Nothing is returned; the summary is printed.
    The printout reads CMC[9], so the gallery needs at least 10 entries.
    """
    # Builtin int replaces np.int, which no longer exists in current NumPy.
    CMC = np.zeros(len(gallery_label), dtype=int)
    ap = 0.0
    num_query = len(query_label)
    for i in range(num_query):
        ap_tmp, CMC_tmp = evaluate(indices[i], query_label[i], query_cam[i], gallery_label, gallery_cam)
        if CMC_tmp[0] == -1:
            # No good gallery match exists for this query.
            continue
        CMC = CMC + CMC_tmp
        ap += ap_tmp
    CMC = CMC.astype(np.float32)
    CMC = CMC / num_query  # average CMC
    print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/num_query))
def evaluate(index, ql, qc, gl, gc):
    """Score one query's ranking list and return ``compute_mAP``'s result.

    index : ranking of gallery indices for this query
    ql, qc : the query's label and camera
    gl, gc : gallery labels and gallery cameras

    Good matches share the query's label but come from a different camera.
    Junk entries, which are ignored during scoring, are gallery items that
    share both label and camera with the query, plus items labelled -1.
    """
    same_label = np.argwhere(gl == ql)
    same_camera = np.argwhere(gc == qc)
    distractors = np.argwhere(gl == -1)

    good = np.setdiff1d(same_label, same_camera, assume_unique=True)
    same_label_and_camera = np.intersect1d(same_label, same_camera)
    # Flattened concatenation, same-camera junk first, then the -1 labels.
    junk = np.concatenate((same_label_and_camera.ravel(), distractors.ravel()))

    return compute_mAP(index, good, junk)
def compute_mAP(index, good_index, junk_index):
    """Compute average precision and the CMC vector for one ranking list.

    index : 1-D ranking of gallery indices, best match first
    good_index : gallery indices that are correct matches
    junk_index : gallery indices to drop from the ranking before scoring

    Returns ``(ap, cmc)``. ``cmc`` has ``len(index)`` entries and is 1 from
    the rank of the first good match onwards. When ``good_index`` is empty,
    ``ap`` is 0 and ``cmc[0]`` is set to -1 as a "skip this query" marker.
    Good matches missing from ``index`` (for example in a truncated list)
    add nothing to ``ap``, although recall is still measured against all
    ``len(good_index)`` matches.
    """
    index = np.asarray(index)
    ap = 0.0
    # Builtin int replaces np.int, which no longer exists in current NumPy.
    cmc = np.zeros(len(index), dtype=int)
    if good_index.size == 0:  # if empty
        cmc[0] = -1
        return ap, cmc

    # remove junk_index
    index = index[np.isin(index, junk_index, invert=True)]

    # find good_index index
    ngood = len(good_index)
    rows_good = np.flatnonzero(np.isin(index, good_index))
    if rows_good.size == 0:
        # No good match made it into the ranking list: a complete miss.
        return ap, cmc

    cmc[rows_good[0]:] = 1
    d_recall = 1.0 / ngood
    # Trapezoidal area under the precision/recall curve. Only the hits that
    # are actually present are visited, so a truncated list cannot index
    # past the end of rows_good.
    for i, row in enumerate(rows_good):
        precision = (i + 1) * 1.0 / (row + 1)
        old_precision = i * 1.0 / row if row != 0 else 1.0
        ap = ap + d_recall * (old_precision + precision) / 2
    return ap, cmc