-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdemo.py
executable file
·128 lines (109 loc) · 4.59 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.nn as nn
import torch.utils.data
import numpy as np
from opt import opt
from dataloader import ImageLoader, DetectionLoader, DetectionProcessor, DataWriter, Mscoco
from yolo.util import write_results, dynamic_write_results
from SPPE.src.main_fast_inference import *
import os
import sys
from tqdm import tqdm
import time
from fn import getTime
from pPose_nms import pose_nms, write_json
from IPython import embed
# Command-line options come from the project's `opt` module; this demo always
# runs with the COCO keypoint dataset definition.
args = opt
args.dataset = 'coco'

# When args.sp is falsy, configure torch.multiprocessing for the multi-worker
# pipeline started below.
# NOTE(review): this deliberately runs at import time rather than under the
# __main__ guard. With the 'forkserver' start method, child processes
# re-import the main module, so presumably this is how the workers also pick
# up the 'file_system' sharing strategy. Confirm before moving it.
if not args.sp:
    torch.multiprocessing.set_start_method('forkserver', force=True)
    torch.multiprocessing.set_sharing_strategy('file_system')
def _read_image_names(inputpath, inputlist):
    """Return the image names to process.

    ``inputlist`` (a text file, one image name per line) takes priority.
    Otherwise the files directly inside the directory ``inputpath`` are used.
    Lines read from ``inputlist`` are returned unchanged, line terminators
    included.

    Raises:
        IOError: if neither source is given, if ``inputpath`` is ``'/'``, or
            if ``inputpath`` is not an existing directory.
    """
    if inputlist:
        with open(inputlist, 'r') as f:
            return f.readlines()
    if inputpath and inputpath != '/':
        # os.walk yields inputpath itself first. Take only that entry:
        # iterating the whole walk and keeping the last `files` list would
        # return the contents of whichever sub-directory was visited last.
        top = next(os.walk(inputpath), None)
        if top is None:
            raise IOError('Error: input directory does not exist: {}'.format(inputpath))
        return top[2]
    raise IOError('Error: must contain either --indir/--list')


def _run_pose_model(pose_model, inps, batch_size):
    """Apply ``pose_model`` to ``inps`` in chunks of at most ``batch_size`` rows.

    Each chunk is moved to the GPU with ``.cuda()`` before the forward pass.
    The per-chunk outputs (heatmaps, B*KP*H*W per the author's note) are
    concatenated along dim 0. The result stays on the GPU; the caller moves
    it to the CPU.
    """
    outputs = []
    for start in range(0, inps.size(0), batch_size):
        # Slicing clamps at the end, so the final chunk may be smaller.
        outputs.append(pose_model(inps[start:start + batch_size].cuda()))
    return torch.cat(outputs)


if __name__ == "__main__":
    # makedirs also creates missing parent directories and tolerates an
    # already-existing output directory.
    os.makedirs(args.outputpath, exist_ok=True)
    im_names = _read_image_names(args.inputpath, args.inputlist)

    # Each pipeline stage is launched with .start(). The original author notes
    # these run as separate processes/threads, so loading overlaps inference.
    data_loader = ImageLoader(im_names, batchSize=args.detbatch, format='yolo',
                              reso=int(args.inp_dim)).start()

    # Flush so the message shows up promptly alongside worker output.
    print('Loading YOLO model..', flush=True)
    det_loader = DetectionLoader(data_loader, batchSize=args.detbatch).start()
    det_processor = DetectionProcessor(det_loader).start()

    # Pose model. The first constructor argument (4 * 1 + 1) is passed through
    # unchanged; its meaning is defined by InferenNet, not visible here.
    pose_dataset = Mscoco()
    pose_net_cls = InferenNet_fast if args.fast_inference else InferenNet
    pose_model = pose_net_cls(4 * 1 + 1, pose_dataset)
    pose_model.cuda()
    pose_model.eval()

    # Per-image durations from getTime:
    # 'dt' = detection, 'pt' = pose estimation, 'pn' = post-processing.
    runtime_profile = {'dt': [], 'pt': [], 'pn': []}

    # Result writer (also writes a video when args.save_video is set).
    writer = DataWriter(args.save_video).start()

    data_len = data_loader.length()
    im_names_desc = tqdm(range(data_len))
    batch_size = args.posebatch

    for _ in im_names_desc:
        start_time = getTime()
        with torch.no_grad():
            (inps, orig_img, im_name, boxes, scores, pt1, pt2) = det_processor.read()
            out_name = os.path.basename(im_name)
            if boxes is None or boxes.nelement() == 0:
                # No detections: still hand the frame to the writer.
                writer.save(None, None, None, None, None, orig_img, out_name)
                continue

            ckpt_time, det_time = getTime(start_time)
            runtime_profile['dt'].append(det_time)

            hm = _run_pose_model(pose_model, inps, batch_size)
            ckpt_time, pose_time = getTime(ckpt_time)
            runtime_profile['pt'].append(pose_time)

            # The CPU copy is timed as post-processing, as before.
            writer.save(boxes, scores, hm.cpu(), pt1, pt2, orig_img, out_name)
            ckpt_time, post_time = getTime(ckpt_time)
            runtime_profile['pn'].append(post_time)

        if args.profile:
            im_names_desc.set_description(
                'det time: {dt:.3f} | pose time: {pt:.2f} | post processing: {pn:.4f}'.format(
                    dt=np.mean(runtime_profile['dt']),
                    pt=np.mean(runtime_profile['pt']),
                    pn=np.mean(runtime_profile['pn'])))

    print('===========================> Finish Model Running.')
    if (args.save_img or args.save_video) and not args.vis_fast:
        print('===========================> Rendering remaining images in the queue...')
        print('===========================> If this step takes too long, you can enable the --vis_fast flag to use fast rendering (real-time).')
    # Wait for the writer to finish. Sleep between polls instead of spinning,
    # so this loop does not pin a CPU core.
    while writer.running():
        time.sleep(0.1)
    writer.stop()
    final_result = writer.results()
    write_json(final_result, args.outputpath)