-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpose_loader.py
64 lines (59 loc) · 1.97 KB
/
pose_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from struct import unpack
import numpy as np
# Load poses (i.e., labels and depths) from a batch file. The file starts with
# the number of samples and the image width and height; each sample then holds
# its labels followed by its depths. Labels and depths are flattened, so the
# number of entries in each sample vector is equal to the number of pixels in
# each sample image. PoseLoader below reads one sample at a time and returns it
# as a (depths, labels) tuple. The label values are as follows.
#
# 0 = Nothing (i.e., the background)
# 1 = Pelvis
# 2 = Spine 1
# 3 = Spine 2
# 4 = Spine 3
# 5 = Left Upper Arm
# 6 = Left Lower Arm
# 7 = Left Hand
# 8 = Right Upper Arm
# 9 = Right Lower Arm
# 10 = Right Hand
# 11 = Neck
# 12 = Head
# 13 = Left Thigh
# 14 = Left Calf
# 15 = Left Foot
# 16 = Right Thigh
# 17 = Right Calf
# 18 = Right Foot
class PoseLoader:
    """
    Loads depths and labels from a batch file, one record at a time.

    File layout (all values little-endian):
      * header: three int32 values -- record count, image width, image height;
      * each record: ``width * height`` uint8 labels followed by
        ``width * height`` float32 depths, both flattened per pixel.

    The file is opened in ``__init__``. It is closed automatically when the
    last record has been read, when the loader is found to be exhausted, or
    when a record turns out to be truncated. It can also be closed explicitly
    via ``close()`` or by using the loader in a ``with`` block.
    """

    def __init__(self, filename):
        """Open *filename* and read its 12-byte header.

        Raises struct.error (after closing the file) if the header is short.
        """
        self.filename = filename
        infile = open(self.filename, 'rb')
        try:
            self.total_n, self.w, self.h = unpack('<iii', infile.read(12))
        except BaseException:
            # Don't leak the handle when the header cannot be parsed.
            infile.close()
            raise
        self.infile = infile
        self.curr_n = 0

    def close(self):
        """Close the underlying file; safe to call more than once."""
        self.infile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def _read_exact(self, size):
        """Read exactly *size* bytes, or close the file and raise EOFError."""
        data = self.infile.read(size)
        if len(data) != size:
            self.close()
            raise EOFError(
                'record %d of %r is truncated: expected %d bytes, got %d'
                % (self.curr_n, self.filename, size, len(data)))
        return data

    def load_next_pose(self):
        """Return the next record as ``(depths, labels)``, or None when done.

        depths: float16 array of length w*h (decoded from stored float32).
        labels: uint8 array of length w*h.
        Note the order: depths first, labels second.

        Raises EOFError if the file ends in the middle of a record.
        """
        if self.curr_n >= self.total_n:
            # Also covers total_n <= 0, where the file was previously never closed.
            self.close()
            return None
        n_pix = self.w * self.h
        # Labels precede depths inside each record.
        labels = np.frombuffer(self._read_exact(n_pix), dtype='<u1')
        depths = np.frombuffer(self._read_exact(n_pix * 4), dtype='<f4')
        self.curr_n += 1
        if self.curr_n >= self.total_n:
            self.close()
        # astype() copies, so the results are writable and independent of the
        # read-only bytes buffer that frombuffer views.
        return depths.astype(np.float16), labels.astype(np.uint8)

    def next_batch(self, size):
        """Read up to *size* records and stack them with ``np.array``.

        Stops early once the file is exhausted. Each element is the
        ``(depths, labels)`` tuple from load_next_pose, so a non-empty result
        has shape (k, 2, w*h). NumPy promotes the uint8 labels to the depths'
        float16 dtype when stacking. An empty batch is ``np.array([])``.
        """
        batch = []
        for _ in range(size):
            pose = self.load_next_pose()
            if pose is None:
                break
            batch.append(pose)
        return np.array(batch)