diff --git a/DSOD300_detection_demo.py b/DSOD300_detection_demo.py new file mode 100644 index 0000000..8439a90 --- /dev/null +++ b/DSOD300_detection_demo.py @@ -0,0 +1,126 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm +import matplotlib.cm as cm +# %matplotlib inline + +plt.rcParams['figure.figsize'] = (10, 10) +plt.rcParams['image.interpolation'] = 'nearest' + +# Make sure that caffe is on the python path: +import os +import sys +sys.path.append('./python') + +import caffe + +# gpu: +# caffe.set_device(0) +# caffe.set_mode_gpu() + +# cpu: +caffe.set_mode_cpu() + +from google.protobuf import text_format +from caffe.proto import caffe_pb2 + +# load PASCAL VOC labels +labelmap_file = 'data/VOC0712/labelmap_voc.prototxt' +file = open(labelmap_file, 'r') +labelmap = caffe_pb2.LabelMap() +text_format.Merge(str(file.read()), labelmap) + +def get_labelname(labelmap, labels): + num_labels = len(labelmap.item) + labelnames = [] + if type(labels) is not list: + labels = [labels] + for label in labels: + found = False + for i in xrange(0, num_labels): + if label == labelmap.item[i].label: + found = True + labelnames.append(labelmap.item[i].display_name) + break + assert found == True + return labelnames + +#Load the net in the test phase for inference, and configure input preprocessing. +model_def = 'examples/07+12+coco/deploy.prototxt' +model_weights = 'examples/07+12+coco/DSOD300_VOC0712+coco.caffemodel' + +net = caffe.Net(model_def, # defines the structure of the model + model_weights, # contains the trained weights + caffe.TEST) # use test mode (e.g., don't perform dropout) + +# input preprocessing: 'data' is the name of the input blob == net.inputs[0] +transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) +transformer.set_transpose('data', (2, 0, 1)) +transformer.set_mean('data', np.array([104,117,123])) # mean pixel +transformer.set_raw_scale('data', 255) # the reference model operates on images in [0,255] range instead of [0,1] +transformer.set_channel_swap('data', (2,1,0)) # the reference model has channels in BGR order instead of RGB + +# DSOD detection + +# set net to batch size of 1 +image_resize = 300 +net.blobs['data'].reshape(1,3,image_resize,image_resize) + +# set colors +colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist() + +#Load an image. +img = "examples/images/cat.jpg" +image = caffe.io.load_image(img) +plt.imshow(image) + +#Run the net and examine the top_k results + +transformed_image = transformer.preprocess('data', image) +net.blobs['data'].data[...] = transformed_image + +# Forward pass. +detections = net.forward()['detection_out'] + +# Parse the outputs. +det_label = detections[0,0,:,1] +det_conf = detections[0,0,:,2] +det_xmin = detections[0,0,:,3] +det_ymin = detections[0,0,:,4] +det_xmax = detections[0,0,:,5] +det_ymax = detections[0,0,:,6] + +# Get detections with confidence higher than 0.6. +top_indices = [i for i, conf in enumerate(det_conf) if conf >= 0.6] + +top_conf = det_conf[top_indices] +top_label_indices = det_label[top_indices].tolist() +top_labels = get_labelname(labelmap, top_label_indices) +top_xmin = det_xmin[top_indices] +top_ymin = det_ymin[top_indices] +top_xmax = det_xmax[top_indices] +top_ymax = det_ymax[top_indices] + +#Plot the boxes +plt.imshow(image) +currentAxis = plt.gca() + +for i in xrange(top_conf.shape[0]): + xmin = int(round(top_xmin[i] * image.shape[1])) + ymin = int(round(top_ymin[i] * image.shape[0])) + xmax = int(round(top_xmax[i] * image.shape[1])) + ymax = int(round(top_ymax[i] * image.shape[0])) + score = top_conf[i] + label = int(top_label_indices[i]) + label_name = top_labels[i] + display_txt = '%s: %.2f'%(label_name, score) + coords = (xmin, ymin), xmax-xmin+1, ymax-ymin+1 + color = colors[label] + currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=6)) + currentAxis.text(xmin, ymin, display_txt, bbox={'facecolor':color, 'alpha':0.5}, fontproperties=fm.FontProperties(size=22)) + +currentAxis.axes.get_xaxis().set_visible(False) +currentAxis.axes.get_yaxis().set_visible(False) + +plt.show() + diff --git a/README.md b/README.md index d969cce..3d64445 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ The code is based on the SSD framework (https://github.com/weiliu89/caffe/tree/s If you use these models or find this helps your research, please cite: - @inproceedings{Shen2017DSOD, title = {DSOD: Learning Deeply Supervised Object Detectors from Scratch}, author = {Shen, Zhiqiang and Liu, Zhuang and Li, Jianguo and Jiang, Yu-Gang and Chen, Yurong and Xue, Xiangyang}, @@ -58,14 +57,12 @@ COCO test-dev 2015 result (COCO has more object categories than VOC dataset, so |:-------|:-----:|:-----:| | DSOD300 (COCO trainval) | 29.3 | [Download (87.2M)](https://drive.google.com/open?id=0B4cvsEOB5eUCYXoxcGRCbVFMNms) | - ## Preparation 0. Install SSD (https://github.com/weiliu89/caffe/tree/ssd) following the instructions there, including: (1) Install SSD caffe; (2) Download PASCAL VOC 2007 and 2012 datasets; and (3) Create LMDB file. Make sure you can run it without any errors. -1. Create a subfolder `dsod` under `example/`, add files `DSOD300_pascal.py`, `DSOD300_pascal++.py`, `DSOD300_coco.py` and `score_DSOD300_pascal.py` to the folder `example/dsod/`. +1. Create a subfolder `dsod` under `example/`, add files `DSOD300_pascal.py`, `DSOD300_pascal++.py`, `DSOD300_coco.py`, `score_DSOD300_pascal.py` and `DSOD300_detection_demo.py` to the folder `example/dsod/`. 2. Replace the file `model_libs.py` in the folder `python/caffe/` with ours. - ## Training & Testing - Train a DSOD model on VOC 07+12: @@ -86,12 +83,16 @@ COCO test-dev 2015 result (COCO has more object categories than VOC dataset, so python examples/dsod/DSOD300_coco.py ``` - - Evaluate the model: - ```shell python examples/dsod/score_DSOD300_pascal.py + ``` + +- Run a demo: + + ```shell + python examples/dsod/DSOD300_detection_demo.py ``` **Note**: You can modify the file `model_lib.py` to design your own network structure as you like. @@ -109,4 +110,3 @@ Zhiqiang Shen (zhiqiangshen13 at fudan.edu.cn) Zhuang Liu (liuzhuangthu at gmail.com) Any comments or suggestions are welcome! -