From ec414fd1f1b670dd21cc97a639b8ccaa6ff13eef Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Sat, 30 Mar 2019 17:24:09 +0530 Subject: [PATCH] Update object detection API --- cvlib/__init__.py | 3 ++- cvlib/gender_detection.py | 48 +++++++++++++++++++++++++++++++++++++++ cvlib/object_detection.py | 12 ++++------ setup.py | 2 +- 4 files changed, 56 insertions(+), 9 deletions(-) create mode 100644 cvlib/gender_detection.py diff --git a/cvlib/__init__.py b/cvlib/__init__.py index 427960e..da17beb 100644 --- a/cvlib/__init__.py +++ b/cvlib/__init__.py @@ -1,7 +1,8 @@ # author: Arun Ponnusamy # website: https://www.arunponnusamy.com -__version__ = "0.1.6" +__version__ = "0.1.7" from .face_detection import detect_face from .object_detection import detect_common_objects +from .gender_detection import detect_gender diff --git a/cvlib/gender_detection.py b/cvlib/gender_detection.py new file mode 100644 index 0000000..7765a1e --- /dev/null +++ b/cvlib/gender_detection.py @@ -0,0 +1,48 @@ +import cv2 +import os +import numpy as np +from keras.utils import get_file +from keras.models import load_model +from keras.preprocessing.image import img_to_array + +is_initialized = False +model = None + +def pre_process(face): + + face = cv2.resize(face, (96,96)) + face = face.astype("float") / 255.0 + face = img_to_array(face) + face = np.expand_dims(face, axis=0) + + return face + +def detect_gender(face): + + global is_initialized + global model + + labels = ['man', 'woman'] + + + if not is_initialized: + + print("[INFO] initializing ... ") + + dwnld_link = "https://s3.ap-south-1.amazonaws.com/arunponnusamy/pre-trained-weights/gender_detection.model" + + model_path = get_file("gender_detection.model", dwnld_link, + cache_dir= os.path.expanduser('~') + os.path.sep + '.cvlib' + os.path.sep + 'pre-trained') + + model = load_model(model_path) + + is_initialized = True + + + face = pre_process(face) + + conf = model.predict(face)[0] + + return (labels, conf) + + diff --git a/cvlib/object_detection.py b/cvlib/object_detection.py index 25da209..d9846f7 100644 --- a/cvlib/object_detection.py +++ b/cvlib/object_detection.py @@ -55,7 +55,7 @@ def draw_bbox(img, bbox, labels, confidence, colors=None, write_conf=False): return img -def detect_common_objects(image): +def detect_common_objects(image, confidence=0.5, nms_thresh=0.3): Height, Width = image.shape[:2] scale = 0.00392 @@ -96,15 +96,13 @@ def detect_common_objects(image): class_ids = [] confidences = [] boxes = [] - conf_threshold = 0.5 - nms_threshold = 0.4 for out in outs: for detection in out: scores = detection[5:] class_id = np.argmax(scores) - confidence = scores[class_id] - if confidence > 0.5: + max_conf = scores[class_id] + if max_conf > confidence: center_x = int(detection[0] * Width) center_y = int(detection[1] * Height) w = int(detection[2] * Width) @@ -112,11 +110,11 @@ def detect_common_objects(image): x = center_x - w / 2 y = center_y - h / 2 class_ids.append(class_id) - confidences.append(float(confidence)) + confidences.append(float(max_conf)) boxes.append([x, y, w, h]) - indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold) + indices = cv2.dnn.NMSBoxes(boxes, confidences, confidence, nms_thresh) bbox = [] label = [] diff --git a/setup.py b/setup.py index 0eb5429..d3d2ce5 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup setup(name='cvlib', - version='0.1.6', + version='0.1.8', description='A high level, easy to use, open source computer vision library for python', url='https://github.com/arunponnusamy/cvlib.git', author='Arun Ponnusamy',