diff --git a/JavSP.py b/JavSP.py index 87550d119..6edc047e1 100644 --- a/JavSP.py +++ b/JavSP.py @@ -20,7 +20,6 @@ from core.print import TqdmOut -from core.baidu_aip import aip_crop_poster # 将StreamHandler的stream修改为TqdmOut,以与Tqdm协同工作 @@ -421,12 +420,21 @@ def reviewMovieID(all_movies, root): def crop_poster_wrapper(fanart_file, poster_file, method='normal'): """包装各种海报裁剪方法,提供统一的调用""" if method == 'baidu': + from core.ai_crop.baidu_aip import aip_crop_poster try: aip_crop_poster(fanart_file, poster_file) except Exception as e: logger.debug('人脸识别失败,回退到常规裁剪方法') logger.debug(e, exc_info=True) crop_poster(fanart_file, poster_file) + elif method == 'retina': + from core.ai_crop.retina import ai_crop_poster + try: + ai_crop_poster(fanart_file, poster_file) + except Exception as e: + logger.debug('人脸识别失败,回退到常规裁剪方法') + logger.debug(e, exc_info=True) + crop_poster(fanart_file, poster_file) else: crop_poster(fanart_file, poster_file) diff --git a/core/baidu_aip.py b/core/ai_crop/baidu_aip.py similarity index 100% rename from core/baidu_aip.py rename to core/ai_crop/baidu_aip.py diff --git a/core/ai_crop/retina.py b/core/ai_crop/retina.py new file mode 100644 index 000000000..44257d0ac --- /dev/null +++ b/core/ai_crop/retina.py @@ -0,0 +1,25 @@ +from retinaface import RetinaFace + +from PIL import Image, ImageOps +def ai_crop_poster(fanart, poster='', hw_ratio=1.42): + im = ImageOps.exif_transpose(Image.open(fanart)) + fanart_w, fanart_h = im.size + poster_h = fanart_h + poster_w = fanart_h / hw_ratio + + resp = RetinaFace.detect_faces(fanart) + + if not 'face_1' in resp: + raise Exception("Retina can't detect face") + + [x1, y1, x2, y2] = resp['face_1']['facial_area'] + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + poster_left = max(center_x - poster_w / 2, 0) + poster_left = min(poster_left, fanart_w - poster_w) + poster_left = int(poster_left) + im_poster = im.crop((poster_left, 0, int(poster_left + poster_w), poster_h)) + if im_poster.mode != 'RGB': + im_poster = im_poster.convert('RGB') + im_poster.save(poster, quality=95) + diff --git a/core/config.py b/core/config.py index 48a0ebcf1..f8866b92b 100644 --- a/core/config.py +++ b/core/config.py @@ -341,6 +341,8 @@ def validate_ai_config(cfg: Config): empty_keys = [i for i in required_keys if not piccfg[i]] if empty_keys: logger.error('使用百度人体分析时,相关设置不能为空: ' + ', '.join(empty_keys)) + elif piccfg.ai_engine.lower() == 'retina': + piccfg.ai_engine = 'retina' else: logger.error('不支持的图像识别引擎: ' + piccfg.ai_engine) diff --git a/requirements.txt b/requirements.txt index a25d78421..06322d9d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ requests==2.31.0 tqdm==4.59.0 urllib3==1.25.11 cryptography==42.0.4 +retina-face==0.0.14