diff --git a/facetracker.py b/facetracker.py index 3fa15a4..ed4c7e4 100644 --- a/facetracker.py +++ b/facetracker.py @@ -10,7 +10,7 @@ parser.add_argument("-p", "--port", type=int, help="Set port for sending tracking data", default=11573) if os.name == 'nt': parser.add_argument("-l", "--list-cameras", type=int, help="Set this to 1 to list the available cameras and quit, set this to 2 or higher to output only the names", default=0) - parser.add_argument("-a", "--list-dcaps", type=int, help="Set this to -1 to list all cameras and their available capabilities, set this to a camera id to list that camera's capabilities.", default=None) + parser.add_argument("-a", "--list-dcaps", type=int, help="Set this to -1 to list all cameras and their available capabilities, set this to a camera id to list that camera's capabilities", default=None) parser.add_argument("-W", "--width", type=int, help="Set camera and raw RGB width", default=640) parser.add_argument("-H", "--height", type=int, help="Set camera and raw RGB height", default=360) parser.add_argument("-F", "--fps", type=int, help="Set camera frames per second", default=24) @@ -43,6 +43,7 @@ parser.add_argument("--face-id-offset", type=int, help="When set, this offset is added to all face ids, which can be useful for mixing tracking data from multiple network sources", default=0) parser.add_argument("--repeat-video", type=int, help="When set to 1 and a video file was specified with -c, the tracker will loop the video until interrupted", default=0) parser.add_argument("--dump-points", type=str, help="When set to a filename, the current face 3D points are made symmetric and dumped to the given file when quitting the visualization with the \"q\" key", default="") +parser.add_argument("--benchmark", type=int, help="When set to 1, the different tracking models are benchmarked, starting with the best and ending with the fastest and with gaze tracking disabled for models with negative IDs", default=0) if os.name == 'nt': parser.add_argument("--use-dshowcapture", type=int, help="When set to 1, libdshowcapture will be used for video input instead of OpenCV", default=1) args = parser.parse_args() @@ -106,8 +107,25 @@ def flush(self): import cv2 import socket import struct +import json from input_reader import InputReader, VideoReader, DShowCaptureReader, try_int -from tracker import Tracker +from tracker import Tracker, get_model_base_path + +if args.benchmark > 0: + model_base_path = get_model_base_path(args.model_dir) + im = cv2.imread(os.path.join(model_base_path, "benchmark.bin"), cv2.IMREAD_COLOR) + results = [] + for model_type in [3, 2, 1, 0, -1]: + tracker = Tracker(224, 224, threshold=0.1, max_threads=args.max_threads, max_faces=1, discard_after=0, scan_every=0, silent=True, model_type=model_type, model_dir=args.model_dir, no_gaze=(model_type < 0), detection_threshold=0.1, use_retinaface=0, max_feature_updates=900, static_model=True if args.no_3d_adapt == 1 else False) + tracker.detected = 1 + tracker.faces = [(0, 0, 224, 224)] + total = 0.0 + for i in range(100): + start = time.perf_counter() + r = tracker.predict(im) + total += time.perf_counter() - start + print(1. / (total / 100.)) + sys.exit(0) target_ip = args.ip target_port = args.port @@ -208,7 +226,7 @@ def flush(self): first = False height, width, channels = frame.shape sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - tracker = Tracker(width, height, threshold=args.threshold, max_threads=args.max_threads, max_faces=args.faces, discard_after=args.discard_after, scan_every=args.scan_every, silent=False if args.silent == 0 else True, model_type=args.model, model_dir=args.model_dir, no_gaze=False if args.gaze_tracking != 0 else True, detection_threshold=args.detection_threshold, use_retinaface=args.scan_retinaface, max_feature_updates=args.max_feature_updates, static_model=True if args.no_3d_adapt == 1 else False) + tracker = Tracker(width, height, threshold=args.threshold, max_threads=args.max_threads, max_faces=args.faces, discard_after=args.discard_after, scan_every=args.scan_every, silent=False if args.silent == 0 else True, model_type=args.model, model_dir=args.model_dir, no_gaze=False if args.gaze_tracking != 0 and args.model >= 0 else True, detection_threshold=args.detection_threshold, use_retinaface=args.scan_retinaface, max_feature_updates=args.max_feature_updates, static_model=True if args.no_3d_adapt == 1 else False) if not args.video_out is None: out = cv2.VideoWriter(args.video_out, cv2.VideoWriter_fourcc('F','F','V','1'), args.video_fps, (width * args.video_scale, height * args.video_scale)) diff --git a/models/benchmark.bin b/models/benchmark.bin new file mode 100644 index 0000000..04fb818 Binary files /dev/null and b/models/benchmark.bin differ diff --git a/tracker.py b/tracker.py index 0a8d296..edbe8a4 100644 --- a/tracker.py +++ b/tracker.py @@ -484,6 +484,15 @@ def adjust_3d(self): self.eye_blink.append(1 - min(max(0, -self.current_features["eye_r"]), 1)) self.eye_blink.append(1 - min(max(0, -self.current_features["eye_l"]), 1)) +def get_model_base_path(model_dir): + model_base_path = resolve(os.path.join("models")) + if model_dir is None: + if not os.path.exists(model_base_path): + model_base_path = resolve(os.path.join("..", "models")) + else: + model_base_path = model_dir + return model_base_path + class Tracker(): def __init__(self, width, height, model_type=3, detection_threshold=0.6, threshold=None, max_faces=1, discard_after=5, scan_every=3, bbox_growth=0.0, max_threads=4, silent=False, model_dir=None, no_gaze=False, use_retinaface=False, max_feature_updates=0, static_model=False, feature_level=2): options = onnxruntime.SessionOptions() @@ -502,12 +511,7 @@ def __init__(self, width, height, model_type=3, detection_threshold=0.6, thresho model = "lm_modelT_opt.onnx" if model_type >= 0: model = self.models[self.model_type] - model_base_path = resolve(os.path.join("models")) - if model_dir is None: - if not os.path.exists(model_base_path): - model_base_path = resolve(os.path.join("..", "models")) - else: - model_base_path = model_dir + model_base_path = get_model_base_path(model_dir) if threshold is None: threshold = 0.6