forked from geohot/twitchslam
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathframe.py
125 lines (104 loc) · 4.11 KB
/
frame.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import cv2
import numpy as np
from scipy.spatial import cKDTree
# project-local tunables for the essential-matrix RANSAC fit below
from constants import RANSAC_RESIDUAL_THRES, RANSAC_MAX_TRIALS
np.set_printoptions(suppress=True)  # print floats plainly, not in scientific notation
from skimage.measure import ransac
from helpers import add_ones, poseRt, fundamentalToRt, normalize, EssentialMatrixTransform, myjet
import logging
logging.basicConfig(format='%(asctime)s [%(levelname)s]= %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class NoFrameMatchError(Exception):
    # Raised by match_frames() when fewer than 8 good matches survive
    # the ratio/distance filtering, so no essential matrix can be fit.
    pass
def extractFeatures(img):
    """Detect corners in *img* and compute ORB descriptors for them.

    Parameters:
        img: HxWx3 uint8 BGR image.

    Returns:
        (pts, des): an (N, 2) float array of pixel coordinates and the
        corresponding (N, 32) uint8 ORB descriptor matrix. Returns an
        empty (0, 2) array and None descriptors when no corners are found.
    """
    orb = cv2.ORB_create()
    # detection on a cheap grayscale (mean of the three channels)
    pts = cv2.goodFeaturesToTrack(np.mean(img, axis=2).astype(np.uint8), 3000, qualityLevel=0.01, minDistance=7)
    if pts is None:
        # goodFeaturesToTrack returns None (not an empty array) on e.g. a blank frame
        return np.zeros((0, 2)), None
    # extraction — positional cv2.KeyPoint(x, y, size): the old `_size`
    # keyword argument was removed in OpenCV >= 4.5.3 and raises TypeError there
    kps = [cv2.KeyPoint(float(f[0][0]), float(f[0][1]), 20) for f in pts]
    kps, descriptors = orb.compute(img, kps)
    # return pts and descriptors
    return np.array([(kp.pt[0], kp.pt[1]) for kp in kps]), descriptors
class Frame(object):
    """One camera frame: intrinsics, pose, detected keypoints/descriptors,
    and the map point observed at each keypoint.

    `self.pts` is a list kept parallel to `self.key_points`: pts[i] is the
    map point observed at keypoint i, or None while unmatched.
    """

    def __init__(self, img, K, pose=np.eye(4), verts=None):
        self.K = np.array(K)        # 3x3 intrinsics (copied)
        # copied via np.array, so the shared np.eye(4) default is never mutated
        self.pose = np.array(pose)  # 4x4 camera pose
        self.img = img
        if img is not None:
            self.h, self.w = img.shape[0:2]
            if verts is None:
                self.key_points, self.descriptors = extractFeatures(img)
            else:
                # precomputed vertices: fabricate unique descriptors by
                # repeating each vertex's one-byte index 32 times
                assert len(verts) < 256
                self.key_points, self.descriptors = verts, np.array(list(range(len(verts))) * 32, np.uint8).reshape(32, len(verts)).T
            self.pts = [None]*len(self.key_points)
        else:
            # fill in later
            self.h, self.w = 0, 0
            self.key_points, self.descriptors, self.pts = None, None, None

    def join_point(self, point, idx):
        """Associate map `point` with keypoint slot `idx` (slot must be free)."""
        assert self.pts[idx] is None
        self.pts[idx] = point

    def leave_point(self, point, idx):
        """Dissociate `point` from keypoint slot `idx`.

        BUGFIX: the original did `del self.pts[idx]`, which shifts every
        later entry and silently breaks the keypoint<->point index
        correspondence. Resetting the slot to None keeps the parallel
        lists aligned.
        """
        assert self.pts[idx] is not None
        assert point in self.pts
        self.pts[idx] = None

    # inverse of intrinsics matrix (computed lazily, cached)
    @property
    def Kinv(self):
        if not hasattr(self, '_Kinv'):
            self._Kinv = np.linalg.inv(self.K)
        return self._Kinv

    # normalized keypoints (pixel coords mapped through Kinv; cached)
    @property
    def kps(self):
        if not hasattr(self, '_kps'):
            self._kps = normalize(self.Kinv, self.key_points)
        return self._kps

    # KD tree of unnormalized keypoints, for spatial queries (cached)
    @property
    def kd(self):
        if not hasattr(self, '_kd'):
            self._kd = cKDTree(self.key_points)
        return self._kd
def match_frames(frame_1: Frame, frame_2: Frame):
    """Match ORB descriptors between two frames and fit an essential matrix.

    Returns:
        (idx1, idx2, Rt): inlier keypoint indices into frame_1 and frame_2,
        and the relative pose recovered from the fitted matrix.

    Raises:
        NoFrameMatchError: when fewer than 8 good matches survive filtering.
    """
    bf = cv2.BFMatcher(cv2.NORM_HAMMING)
    matches = bf.knnMatch(frame_1.descriptors, frame_2.descriptors, k=2)

    # Lowe's ratio test
    ret = []
    idx1, idx2 = [], []
    idx1s, idx2s = set(), set()
    for pair in matches:
        if len(pair) < 2:
            # knnMatch may return fewer than k neighbours when the
            # descriptor sets are small; unpacking would raise ValueError
            continue
        m, n = pair
        if m.distance < 0.75*n.distance:
            p1 = frame_1.kps[m.queryIdx]
            p2 = frame_2.kps[m.trainIdx]
            # be within orb distance 32
            if m.distance < 32:
                # keep each keypoint at most once per frame; set membership
                # makes the dedup O(1) per match
                if m.queryIdx not in idx1s and m.trainIdx not in idx2s:
                    idx1.append(m.queryIdx)
                    idx2.append(m.trainIdx)
                    idx1s.add(m.queryIdx)
                    idx2s.add(m.trainIdx)
                    ret.append((p1, p2))

    # no duplicates
    assert(len(set(idx1)) == len(idx1))
    assert(len(set(idx2)) == len(idx2))

    if len(ret) < 8:
        # NOTE(review): Frame defines no `id` attribute in this file —
        # presumably it is assigned by the mapping code; confirm, or this
        # log line itself raises AttributeError.
        logger.warning("Skipping match of frame {} to frame {}".format(frame_1.id, frame_2.id))
        raise NoFrameMatchError

    ret = np.array(ret)
    idx1 = np.array(idx1)
    idx2 = np.array(idx2)

    # fit matrix
    model, inliers = ransac((ret[:, 0], ret[:, 1]),
                            EssentialMatrixTransform,
                            min_samples=8,
                            residual_threshold=RANSAC_RESIDUAL_THRES,
                            max_trials=RANSAC_MAX_TRIALS)

    logger.info("Quality: {}".format(np.mean(ret[:, 0]-ret[:, 1])))
    logger.info("Matches: %d -> %d -> %d -> %d" % (len(frame_1.descriptors), len(matches), len(inliers), sum(inliers)))
    return idx1[inliers], idx2[inliers], fundamentalToRt(model.params)