diff --git a/README.md b/README.md
index 0366909..303e21d 100644
--- a/README.md
+++ b/README.md
@@ -15,11 +15,12 @@ ### Introduction
-This project mainly performs layout analysis on document images: given a document image (a paper screenshot, research report, etc.), it locates the category and position of each region, such as titles, paragraphs, tables, and figures.
+
+This project brings together open-source layout analysis models from across the web: given a document image (a paper screenshot, research report, etc.), it locates the category and position of each region, such as titles, paragraphs, tables, and figures.
 
 ⚠️ Note: because layouts differ substantially across scenarios, no single model can currently handle all of them. If the models below do not work well for your actual business, it is recommended to build your own training set and fine-tune.
 
-Layout analysis is currently supported for the following scenarios:
+The layout analysis models currently supported are as follows:
 
 |`model_type`| Layout type | Model name | Supported classes |
 | :------ | :----- | :------ | :----- |
@@ -30,30 +31,38 @@
 | `yolov8n_layout_report`| Research report | `yolov8n_layout_report.onnx` | `['Text', 'Title', 'Header', 'Footer', 'Figure', 'Table', 'Toc', 'Figure caption', 'Table caption']` |
 | `yolov8n_layout_publaynet`| English | `yolov8n_layout_publaynet.onnx` | `["Text", "Title", "List", "Table", "Figure"]` |
 | `yolov8n_layout_general6`| General | `yolov8n_layout_general6.onnx` | `["Text", "Title", "Figure", "Table", "Caption", "Equation"]` |
+| 🔥`doclayout_yolo`| General | `doclayout_yolo_docstructbench_imgsz1024.onnx` | `['title', 'text', 'abandon', 'figure', 'figure_caption', 'table', 'table_caption', 'table_footnote', 'isolate_formula', 'formula_caption']` |
 
 The PP models come from [PaddleOCR layout analysis](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md).
 
 The yolov8n series comes from [360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis).
 
+(Recommended) 🔥The doclayout_yolo model comes from [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO), currently the best open-source model of the group. It supports layout detection for 7 document types: academic papers, textbooks, financial reports, exam papers, fuzzy scans, PPT, and posters. Notably, its classes include `abandon`, which covers page headers and footers so that they can be discarded quickly in later processing.
+
 Model download link: [link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
 
 ### Installation
+
 Because the model is small, the Chinese layout analysis model (`layout_cdla.onnx`) is prepackaged into the whl package. For Chinese layout analysis, it can be installed and used directly:
 
 ```bash
-$ pip install rapid-layout
+pip install rapid-layout
 ```
 
 ### Usage
+
 #### Run with a Python script
+
 ```python
 import cv2
+
 from rapid_layout import RapidLayout, VisLayout
 
 # See the table above for model_type values. When a different model_type is specified, the corresponding model is downloaded automatically to the installation directory.
-layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.2)
 
-img = cv2.imread('test_images/layout.png')
+img_path = "tests/test_files/financial.jpg"
+img = cv2.imread(img_path)
 boxes, scores, class_names, elapse = layout_engine(img)
 ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
@@ -61,41 +70,51 @@
 if ploted_img is not None:
     cv2.imwrite("layout_res.png", ploted_img)
 ```
 
+### Visualization result
+
+<div align="center">
+    <img src="...">
+</div>
+
 #### Run from the terminal
+
 ```bash
 $ rapid_layout -h
 usage: rapid_layout [-h] -img IMG_PATH
-                    [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                    [--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                    [--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                    [--use_cuda] [--use_dml] [-v]
+                    [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--use_cuda] [--use_dml] [-v]
 
 options:
   -h, --help            show this help message and exit
   -img IMG_PATH, --img_path IMG_PATH
                         Path to image for layout.
-  -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         Support model type
-  --conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  --conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         Box threshold, the range is [0, 1]
-  --iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  --iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         IoU threshold, the range is [0, 1]
   --use_cuda            Whether to use cuda.
   --use_dml             Whether to use DirectML, which only works in Windows10+.
   -v, --vis             Whether to visualize the layout results.
 ```
+
 - Example:
+
     ```bash
-    $ rapid_layout -v -img test_images/layout.png
+    rapid_layout -v -img test_images/layout.png
     ```
 
-
 ### GPU inference
+
 - Because the layout models take a fixed input image size, `onnxruntime-gpu` can be used to speed up inference.
 - The `rapid_layout` package depends on the CPU build of `onnxruntime` by default; for GPU inference, `onnxruntime-gpu` must be installed manually.
 - For detailed usage and benchmarks, see [AI Studio](https://aistudio.baidu.com/projectdetail/8094594).
 
 #### Installation
+
 ```bash
 pip install rapid_layout
 pip uninstall onnxruntime
@@ -106,13 +125,14 @@
 pip install onnxruntime-gpu
 ```
 
 #### Usage
+
 ```python
 import cv2
 from rapid_layout import RapidLayout
 from pathlib import Path
 
 # Note: the use_cuda parameter must be specified here
-layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla", use_cuda=True)
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.2, use_cuda=True)
 
 # warm up
 layout_engine("images/12027_5.png")
@@ -128,15 +148,10 @@
 avg_elapse = sum(elapses) / len(elapses)
 print(f'avg elapse: {avg_elapse:.4f}')
 ```
 
-### Visualization result
-
-<div align="center">
-    <img src="...">
-</div>
-
-
 ### Reference projects
+
+- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
 - [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
 - [360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
 - [ONNX-YOLOv8-Object-Detection](https://github.com/ibaiGorordo/ONNX-YOLOv8-Object-Detection)
-- [ChineseDocumentPDF](https://github.com/SWHL/ChineseDocumentPDF)
\ No newline at end of file
+- [ChineseDocumentPDF](https://github.com/SWHL/ChineseDocumentPDF)
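The README change above highlights that `doclayout_yolo` emits an `abandon` class for page headers and footers so they can be dropped quickly. A minimal sketch of that filtering step, building on the `RapidLayout`/`VisLayout` API shown in the README (the image path and output file name here are only illustrative):

```python
import cv2

from rapid_layout import RapidLayout, VisLayout

engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.2)

img = cv2.imread("tests/test_files/PMC3576793_00004.jpg")
boxes, scores, class_names, elapse = engine(img)

# Keep every detection that is not a header/footer ("abandon") region.
keep = [i for i, name in enumerate(class_names) if name != "abandon"]
boxes, scores = boxes[keep], scores[keep]
class_names = [class_names[i] for i in keep]

vis_img = VisLayout.draw_detections(img, boxes, scores, class_names)
if vis_img is not None:
    cv2.imwrite("layout_res_no_headers.png", vis_img)
```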
diff --git a/demo.py b/demo.py
index 7e47d5a..bf398c8 100644
--- a/demo.py
+++ b/demo.py
@@ -5,9 +5,9 @@
 from rapid_layout import RapidLayout, VisLayout
 
-layout_engine = RapidLayout(model_type="yolov8n_layout_paper")
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.1)
 
-img_path = "tests/test_files/layout.png"
+img_path = "tests/test_files/PMC3576793_00004.jpg"
 img = cv2.imread(img_path)
 boxes, scores, class_names, elapse = layout_engine(img)
diff --git a/rapid_layout/main.py b/rapid_layout/main.py
index c17f06a..84125d4 100644
--- a/rapid_layout/main.py
+++ b/rapid_layout/main.py
@@ -10,6 +10,8 @@
 import numpy as np
 
 from .utils import (
+    DocLayoutPostProcess,
+    DocLayoutPreProcess,
     DownloadModel,
     LoadImage,
     OrtInferSession,
@@ -33,6 +35,7 @@
     "yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
     "yolov8n_layout_publaynet": f"{ROOT_URL}/yolov8n_layout_publaynet.onnx",
     "yolov8n_layout_general6": f"{ROOT_URL}/yolov8n_layout_general6.onnx",
+    "doclayout_yolo": f"{ROOT_URL}/doclayout_yolo_docstructbench_imgsz1024.onnx",
 }
 
 DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")
@@ -72,12 +75,20 @@ def __init__(
         self.yolov8_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
         self.yolov8_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
 
+        # doclayout
+        self.doclayout_input_shape = (1024, 1024)
+        self.doclayout_preprocess = DocLayoutPreProcess(
+            img_size=self.doclayout_input_shape
+        )
+        self.doclayout_postprocess = DocLayoutPostProcess(labels, conf_thres, iou_thres)
+
         self.load_img = LoadImage()
 
         self.pp_layout_type = [k for k in KEY_TO_MODEL_URL if k.startswith("pp")]
         self.yolov8_layout_type = [
             k for k in KEY_TO_MODEL_URL if k.startswith("yolov8n")
         ]
+        self.doclayout_type = [k for k in KEY_TO_MODEL_URL if k.startswith("doclayout")]
 
     def __call__(
         self, img_content: Union[str, np.ndarray, bytes, Path]
@@ -91,6 +102,9 @@ def __call__(
         if self.model_type in self.yolov8_layout_type:
             return self.yolov8_layout(img, ori_img_shape)
 
+        if self.model_type in self.doclayout_type:
+            return self.doclayout_layout(img, ori_img_shape)
+
         raise ValueError(f"{self.model_type} is not supported.")
 
     def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
@@ -114,6 +128,17 @@ def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
         elapse = time.time() - s_time
         return boxes, scores, class_names, elapse
 
+    def doclayout_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+        s_time = time.time()
+
+        input_tensor = self.doclayout_preprocess(img)
+        outputs = self.session(input_tensor)
+        boxes, scores, class_names = self.doclayout_postprocess(
+            outputs, ori_img_shape, self.doclayout_input_shape
+        )
+        elapse = time.time() - s_time
+        return boxes, scores, class_names, elapse
+
     @staticmethod
     def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
         if model_path is not None:
diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py
index ed27313..f9b7953 100644
--- a/rapid_layout/utils/__init__.py
+++ b/rapid_layout/utils/__init__.py
@@ -5,6 +5,6 @@
 from .infer_engine import OrtInferSession
 from .load_image import LoadImage, LoadImageError
 from .logger import get_logger
-from .post_prepross import PPPostProcess, YOLOv8PostProcess
-from .pre_procss import PPPreProcess, YOLOv8PreProcess
+from .post_prepross import DocLayoutPostProcess, PPPostProcess, YOLOv8PostProcess
+from .pre_procss import DocLayoutPreProcess, PPPreProcess, YOLOv8PreProcess
 from .vis_res import VisLayout
diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py
index 516fc33..bba6916 100644
--- a/rapid_layout/utils/post_prepross.py
+++ b/rapid_layout/utils/post_prepross.py
@@ -288,23 +288,61 @@ def extract_boxes(self, predictions):
         boxes = predictions[:, :4]
 
         # Scale boxes to original image dimensions
-        boxes = self.rescale_boxes(boxes)
+        boxes = rescale_boxes(
+            boxes, self.input_width, self.input_height, self.img_width, self.img_height
+        )
 
         # Convert boxes to xyxy format
         boxes = xywh2xyxy(boxes)
         return boxes
 
-    def rescale_boxes(self, boxes):
+
+class DocLayoutPostProcess:
+    def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
+        self.labels = labels
+        self.conf_threshold = conf_thres
+        self.iou_threshold = iou_thres
+        self.input_width, self.input_height = None, None
+        self.img_width, self.img_height = None, None
+
+    def __call__(
+        self,
+        output,
+        ori_img_shape: Tuple[int, int],
+        img_shape: Tuple[int, int] = (1024, 1024),
+    ):
+        self.img_height, self.img_width = ori_img_shape
+        self.input_height, self.input_width = img_shape
+
+        output = output[0].squeeze()
+        boxes = output[:, :-2]
+        confidences = output[:, -2]
+        class_ids = output[:, -1].astype(int)
+
+        mask = confidences > self.conf_threshold
+        boxes = boxes[mask, :]
+        confidences = confidences[mask]
+        class_ids = class_ids[mask]
+
         # Rescale boxes to original image dimensions
-        input_shape = np.array(
-            [self.input_width, self.input_height, self.input_width, self.input_height]
-        )
-        boxes = np.divide(boxes, input_shape, dtype=np.float32)
-        boxes *= np.array(
-            [self.img_width, self.img_height, self.img_width, self.img_height]
+        boxes = rescale_boxes(
+            boxes,
+            self.input_width,
+            self.input_height,
+            self.img_width,
+            self.img_height,
         )
-        return boxes
+
+        labels = [self.labels[i] for i in class_ids]
+        return boxes, confidences, labels
+
+
+def rescale_boxes(boxes, input_width, input_height, img_width, img_height):
+    # Rescale boxes to original image dimensions
+    input_shape = np.array([input_width, input_height, input_width, input_height])
+    boxes = np.divide(boxes, input_shape, dtype=np.float32)
+    boxes *= np.array([img_width, img_height, img_width, img_height])
+    return boxes
 
 
 def nms(boxes, scores, iou_threshold):
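In the new `DocLayoutPostProcess`, each row of the squeezed model output carries four box coordinates followed by a confidence and a class id; after thresholding, the shared module-level `rescale_boxes` helper maps the surviving boxes from the fixed 1024×1024 input space back onto the original page. A small worked example of that mapping (the 2480×3508 page size is an assumed A4-style scan, not a value from this repo):

```python
import numpy as np

from rapid_layout.utils.post_prepross import rescale_boxes

# One box detected in the 1024 x 1024 model-input space.
box = np.array([[512.0, 256.0, 768.0, 320.0]])

# Divide by the input size, then multiply by the original page size.
page_box = rescale_boxes(box, 1024, 1024, 2480, 3508)
print(page_box)  # [[1240.    877.   1860.   1096.25]]
```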
diff --git a/rapid_layout/utils/pre_procss.py b/rapid_layout/utils/pre_procss.py
index 78ce748..0e8727a 100644
--- a/rapid_layout/utils/pre_procss.py
+++ b/rapid_layout/utils/pre_procss.py
@@ -51,3 +51,17 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
         input_img = input_img.transpose(2, 0, 1)
         input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
         return input_tensor
+
+
+class DocLayoutPreProcess:
+
+    def __init__(self, img_size: Tuple[int, int]):
+        self.img_size = img_size
+
+    def __call__(self, image: np.ndarray) -> np.ndarray:
+        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        input_img = cv2.resize(input_img, self.img_size)
+        input_img = input_img / 255.0
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+        return input_tensor
diff --git a/tests/test_layout.py b/tests/test_layout.py
index f7f507e..b40713e 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -22,7 +22,12 @@
 @pytest.mark.parametrize(
-    "model_type,gt", [("yolov8n_layout_publaynet", 12), ("yolov8n_layout_general6", 13)]
+    "model_type,gt",
+    [
+        ("yolov8n_layout_publaynet", 12),
+        ("yolov8n_layout_general6", 13),
+        ("doclayout_yolo", 14),
+    ],
 )
 def test_yolov8n_layout(model_type, gt):
     img_path = test_file_dir / "PMC3576793_00004.jpg"
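Finally, `DocLayoutPreProcess` in `pre_procss.py` converts BGR to RGB, resizes to the 1024×1024 model input, scales pixels to [0, 1], and emits an NCHW float32 tensor (note the resize must operate on the converted image, as fixed above). A quick sanity check on a dummy page (random pixels; the shapes are the point):

```python
import numpy as np

from rapid_layout.utils.pre_procss import DocLayoutPreProcess

preprocess = DocLayoutPreProcess(img_size=(1024, 1024))

# Any HxWx3 uint8 BGR image works; a random "page" is enough here.
dummy_page = np.random.randint(0, 256, (3508, 2480, 3), dtype=np.uint8)
input_tensor = preprocess(dummy_page)

print(input_tensor.shape, input_tensor.dtype)  # (1, 3, 1024, 1024) float32
```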