Skip to content

Commit

Permalink
Added --benchmark option and disabled gaze tracking for models with…
Browse files Browse the repository at this point in the history
… negative IDs.
  • Loading branch information
emilianavt committed Sep 26, 2020
1 parent 78bfd8c commit d8c8ed0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
24 changes: 21 additions & 3 deletions facetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
parser.add_argument("-p", "--port", type=int, help="Set port for sending tracking data", default=11573)
if os.name == 'nt':
parser.add_argument("-l", "--list-cameras", type=int, help="Set this to 1 to list the available cameras and quit, set this to 2 or higher to output only the names", default=0)
parser.add_argument("-a", "--list-dcaps", type=int, help="Set this to -1 to list all cameras and their available capabilities, set this to a camera id to list that camera's capabilities.", default=None)
parser.add_argument("-a", "--list-dcaps", type=int, help="Set this to -1 to list all cameras and their available capabilities, set this to a camera id to list that camera's capabilities", default=None)
parser.add_argument("-W", "--width", type=int, help="Set camera and raw RGB width", default=640)
parser.add_argument("-H", "--height", type=int, help="Set camera and raw RGB height", default=360)
parser.add_argument("-F", "--fps", type=int, help="Set camera frames per second", default=24)
Expand Down Expand Up @@ -43,6 +43,7 @@
parser.add_argument("--face-id-offset", type=int, help="When set, this offset is added to all face ids, which can be useful for mixing tracking data from multiple network sources", default=0)
parser.add_argument("--repeat-video", type=int, help="When set to 1 and a video file was specified with -c, the tracker will loop the video until interrupted", default=0)
parser.add_argument("--dump-points", type=str, help="When set to a filename, the current face 3D points are made symmetric and dumped to the given file when quitting the visualization with the \"q\" key", default="")
parser.add_argument("--benchmark", type=int, help="When set to 1, the different tracking models are benchmarked, starting with the best and ending with the fastest and with gaze tracking disabled for models with negative IDs", default=0)
if os.name == 'nt':
parser.add_argument("--use-dshowcapture", type=int, help="When set to 1, libdshowcapture will be used for video input instead of OpenCV", default=1)
args = parser.parse_args()
Expand Down Expand Up @@ -106,8 +107,25 @@ def flush(self):
import cv2
import socket
import struct
import json
from input_reader import InputReader, VideoReader, DShowCaptureReader, try_int
from tracker import Tracker
from tracker import Tracker, get_model_base_path

if args.benchmark > 0:
model_base_path = get_model_base_path(args.model_dir)
im = cv2.imread(os.path.join(model_base_path, "benchmark.bin"), cv2.IMREAD_COLOR)
results = []
for model_type in [3, 2, 1, 0, -1]:
tracker = Tracker(224, 224, threshold=0.1, max_threads=args.max_threads, max_faces=1, discard_after=0, scan_every=0, silent=True, model_type=model_type, model_dir=args.model_dir, no_gaze=(model_type < 0), detection_threshold=0.1, use_retinaface=0, max_feature_updates=900, static_model=True if args.no_3d_adapt == 1 else False)
tracker.detected = 1
tracker.faces = [(0, 0, 224, 224)]
total = 0.0
for i in range(100):
start = time.perf_counter()
r = tracker.predict(im)
total += time.perf_counter() - start
print(1. / (total / 100.))
sys.exit(0)

target_ip = args.ip
target_port = args.port
Expand Down Expand Up @@ -208,7 +226,7 @@ def flush(self):
first = False
height, width, channels = frame.shape
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
tracker = Tracker(width, height, threshold=args.threshold, max_threads=args.max_threads, max_faces=args.faces, discard_after=args.discard_after, scan_every=args.scan_every, silent=False if args.silent == 0 else True, model_type=args.model, model_dir=args.model_dir, no_gaze=False if args.gaze_tracking != 0 else True, detection_threshold=args.detection_threshold, use_retinaface=args.scan_retinaface, max_feature_updates=args.max_feature_updates, static_model=True if args.no_3d_adapt == 1 else False)
tracker = Tracker(width, height, threshold=args.threshold, max_threads=args.max_threads, max_faces=args.faces, discard_after=args.discard_after, scan_every=args.scan_every, silent=False if args.silent == 0 else True, model_type=args.model, model_dir=args.model_dir, no_gaze=False if args.gaze_tracking != 0 and args.model >= 0 else True, detection_threshold=args.detection_threshold, use_retinaface=args.scan_retinaface, max_feature_updates=args.max_feature_updates, static_model=True if args.no_3d_adapt == 1 else False)
if not args.video_out is None:
out = cv2.VideoWriter(args.video_out, cv2.VideoWriter_fourcc('F','F','V','1'), args.video_fps, (width * args.video_scale, height * args.video_scale))

Expand Down
Binary file added models/benchmark.bin
Binary file not shown.
16 changes: 10 additions & 6 deletions tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,15 @@ def adjust_3d(self):
self.eye_blink.append(1 - min(max(0, -self.current_features["eye_r"]), 1))
self.eye_blink.append(1 - min(max(0, -self.current_features["eye_l"]), 1))

def get_model_base_path(model_dir):
model_base_path = resolve(os.path.join("models"))
if model_dir is None:
if not os.path.exists(model_base_path):
model_base_path = resolve(os.path.join("..", "models"))
else:
model_base_path = model_dir
return model_base_path

class Tracker():
def __init__(self, width, height, model_type=3, detection_threshold=0.6, threshold=None, max_faces=1, discard_after=5, scan_every=3, bbox_growth=0.0, max_threads=4, silent=False, model_dir=None, no_gaze=False, use_retinaface=False, max_feature_updates=0, static_model=False, feature_level=2):
options = onnxruntime.SessionOptions()
Expand All @@ -502,12 +511,7 @@ def __init__(self, width, height, model_type=3, detection_threshold=0.6, thresho
model = "lm_modelT_opt.onnx"
if model_type >= 0:
model = self.models[self.model_type]
model_base_path = resolve(os.path.join("models"))
if model_dir is None:
if not os.path.exists(model_base_path):
model_base_path = resolve(os.path.join("..", "models"))
else:
model_base_path = model_dir
model_base_path = get_model_base_path(model_dir)

if threshold is None:
threshold = 0.6
Expand Down

0 comments on commit d8c8ed0

Please sign in to comment.