Skip to content

Commit

Permalink
implement retina AI for face recognition
Browse files Browse the repository at this point in the history
  • Loading branch information
glyh committed Mar 9, 2024
1 parent f12f5a9 commit 8e10e3c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 1 deletion.
10 changes: 9 additions & 1 deletion JavSP.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


from core.print import TqdmOut
from core.baidu_aip import aip_crop_poster


# 将StreamHandler的stream修改为TqdmOut,以与Tqdm协同工作
Expand Down Expand Up @@ -421,12 +420,21 @@ def reviewMovieID(all_movies, root):
def crop_poster_wrapper(fanart_file, poster_file, method='normal'):
"""包装各种海报裁剪方法,提供统一的调用"""
if method == 'baidu':
from core.ai_crop.baidu_aip import aip_crop_poster
try:
aip_crop_poster(fanart_file, poster_file)
except Exception as e:
logger.debug('人脸识别失败,回退到常规裁剪方法')
logger.debug(e, exc_info=True)
crop_poster(fanart_file, poster_file)
elif method == 'retina':
from core.ai_crop.retina import ai_crop_poster
try:
ai_crop_poster(fanart_file, poster_file)
except Exception as e:
logger.debug('人脸识别失败,回退到常规裁剪方法')
logger.debug(e, exc_info=True)
crop_poster(fanart_file, poster_file)
else:
crop_poster(fanart_file, poster_file)

Expand Down
File renamed without changes.
25 changes: 25 additions & 0 deletions core/ai_crop/retina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from retinaface import RetinaFace

from PIL import Image, ImageOps
def ai_crop_poster(fanart, poster='', hw_ratio=1.42):
im = ImageOps.exif_transpose(Image.open(fanart))
fanart_w, fanart_h = im.size
poster_h = fanart_h
poster_w = fanart_h / hw_ratio

resp = RetinaFace.detect_faces(fanart)

if not 'face_1' in resp:
raise Exception("Retina can't detect face")

[x1, y1, x2, y2] = resp['face_1']['facial_area']
center_x = (x1 + x2) / 2
center_y = (y1 + y2) / 2
poster_left = max(center_x - poster_w / 2, 0)
poster_left = min(poster_left, fanart_w - poster_w)
poster_left = int(poster_left)
im_poster = im.crop((poster_left, 0, int(poster_left + poster_w), poster_h))
if im_poster.mode != 'RGB':
im_poster = im_poster.convert('RGB')
im_poster.save(poster, quality=95)

2 changes: 2 additions & 0 deletions core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def validate_ai_config(cfg: Config):
empty_keys = [i for i in required_keys if not piccfg[i]]
if empty_keys:
logger.error('使用百度人体分析时,相关设置不能为空: ' + ', '.join(empty_keys))
elif piccfg.ai_engine.lower() == 'retina':
piccfg.ai_engine = 'retina'
else:
logger.error('不支持的图像识别引擎: ' + piccfg.ai_engine)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ requests==2.31.0
tqdm==4.59.0
urllib3==1.25.11
cryptography==42.0.4
retina-face==0.0.14

0 comments on commit 8e10e3c

Please sign in to comment.