forked from szq0214/DSOD
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
133 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Demo script setup: plotting defaults, Caffe import path and compute mode.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.cm as cm
# %matplotlib inline

# Global matplotlib defaults applied to every figure drawn by this script.
plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['image.interpolation'] = 'nearest'

# Make sure that caffe is on the python path:
# NOTE(review): './python' is relative to the working directory, so this
# presumably expects the script to be launched from the Caffe root - confirm.
import os
import sys
sys.path.append('./python')

# Imported only after the sys.path tweak above so the local pycaffe is found.
import caffe

# gpu:
# caffe.set_device(0)
# caffe.set_mode_gpu()

# cpu:
caffe.set_mode_cpu()

# text_format parses text-format protobuf files; caffe_pb2 provides the
# message classes (LabelMap is used further down).
from google.protobuf import text_format
from caffe.proto import caffe_pb2
|
||
# load PASCAL VOC labels
labelmap_file = 'data/VOC0712/labelmap_voc.prototxt'
labelmap = caffe_pb2.LabelMap()
# Parse the text-format prototxt into the LabelMap message. The context
# manager guarantees the file handle is closed, and the handle name no longer
# shadows the Python 2 builtin `file`.
with open(labelmap_file, 'r') as labelmap_fh:
    text_format.Merge(labelmap_fh.read(), labelmap)
|
||
def get_labelname(labelmap, labels):
    """Translate numeric class labels into their display names.

    Args:
        labelmap: Object whose ``item`` sequence holds entries exposing
            ``label`` and ``display_name`` attributes (in this script a
            ``caffe_pb2.LabelMap``).
        labels: A single label, or a list/tuple of labels. Labels are
            matched with ``==``, so ``8.0`` matches an entry labelled ``8``.

    Returns:
        A list with one display name per input label, in input order.

    Raises:
        ValueError: If a label has no matching entry in ``labelmap``.
    """
    if not isinstance(labels, (list, tuple)):
        labels = [labels]

    labelnames = []
    for label in labels:
        # First matching entry wins.
        for item in labelmap.item:
            if label == item.label:
                labelnames.append(item.display_name)
                break
        else:
            # Raise explicitly instead of asserting: asserts vanish under
            # `python -O`, which would silently return fewer names than labels.
            raise ValueError('label %r not found in labelmap' % (label,))
    return labelnames
|
||
# Load the net in the test phase for inference, and configure input preprocessing.
model_def = 'examples/07+12+coco/deploy.prototxt'
model_weights = 'examples/07+12+coco/DSOD300_VOC0712+coco.caffemodel'

net = caffe.Net(model_def,      # defines the structure of the model
                model_weights,  # contains the trained weights
                caffe.TEST)     # use test mode (e.g., don't perform dropout)

# DSOD detection

# Set the net to a batch size of 1 at the detector's input resolution.
# This is done before the Transformer is created so that the Transformer is
# built from the final input shape; otherwise a deploy file declaring a
# different input size would yield a preprocessed image that does not fit
# the reshaped blob.
image_resize = 300
net.blobs['data'].reshape(1, 3, image_resize, image_resize)

# input preprocessing: 'data' is the name of the input blob == net.inputs[0]
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))  # reorder image axes to (2, 0, 1)
transformer.set_mean('data', np.array([104, 117, 123]))  # mean pixel
transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
transformer.set_channel_swap('data', (2, 1, 0))  # the reference model has channels in BGR order instead of RGB
|
||
# Colour table: 21 RGBA entries sampled evenly across the HSV colormap,
# later indexed by class label when drawing boxes.
colors = cm.hsv(np.linspace(0, 1, 21)).tolist()

# Read the demo image from disk and show it.
img = "examples/images/cat.jpg"
image = caffe.io.load_image(img)
plt.imshow(image)

# Preprocess the image and copy it into the network's input blob.
transformed_image = transformer.preprocess('data', image)
input_blob = net.blobs['data']
input_blob.data[...] = transformed_image

# Run one forward pass and keep the detection output blob.
forward_outputs = net.forward()
detections = forward_outputs['detection_out']
|
||
# Split the detection rows into per-field vectors. Columns 1..6 of each row
# are read as label, confidence, xmin, ymin, xmax, ymax (column 0 is unused).
(det_label, det_conf,
 det_xmin, det_ymin,
 det_xmax, det_ymax) = detections[0, 0, :, 1:7].T

# Keep only detections whose confidence is at least 0.6.
top_indices = np.flatnonzero(det_conf >= 0.6).tolist()

top_conf = det_conf[top_indices]
top_label_indices = det_label[top_indices].tolist()
top_labels = get_labelname(labelmap, top_label_indices)
top_xmin, top_ymin, top_xmax, top_ymax = (
    coord[top_indices]
    for coord in (det_xmin, det_ymin, det_xmax, det_ymax)
)
|
||
# Plot the boxes
plt.imshow(image)
currentAxis = plt.gca()

# Box coordinates are multiplied by the image size below, i.e. they are
# treated as fractions of the image width/height.
img_height, img_width = image.shape[0], image.shape[1]
# Same font for every caption, so build it once outside the loop.
label_font = fm.FontProperties(size=22)

# `range` (not `xrange`) so the loop runs on both Python 2 and Python 3.
for i in range(top_conf.shape[0]):
    xmin = int(round(top_xmin[i] * img_width))
    ymin = int(round(top_ymin[i] * img_height))
    xmax = int(round(top_xmax[i] * img_width))
    ymax = int(round(top_ymax[i] * img_height))
    score = top_conf[i]
    label = int(top_label_indices[i])  # labels arrive as floats; int needed to index `colors`
    label_name = top_labels[i]
    display_txt = '%s: %.2f'%(label_name, score)
    # Rectangle takes (x, y) of the top-left corner, then width and height.
    coords = (xmin, ymin), xmax - xmin + 1, ymax - ymin + 1
    color = colors[label]
    currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=6))
    currentAxis.text(xmin, ymin, display_txt, bbox={'facecolor':color, 'alpha':0.5}, fontproperties=label_font)

# Hide the axis ticks; only the image and its boxes should be visible.
currentAxis.axes.get_xaxis().set_visible(False)
currentAxis.axes.get_yaxis().set_visible(False)

plt.show()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters