From 133e0b42a3a6423ea3eb9aa0607aee91c357dee6 Mon Sep 17 00:00:00 2001
From: SWHL <liekkaskono@163.com>
Date: Fri, 18 Oct 2024 21:41:08 +0800
Subject: [PATCH] feat: Support DocLayout YOLO model

---
 README.md                           | 61 ++++++++++++++++++-----------
 demo.py                             |  4 +-
 rapid_layout/main.py                | 25 ++++++++++++
 rapid_layout/utils/__init__.py      |  4 +-
 rapid_layout/utils/post_prepross.py | 56 +++++++++++++++++++++-----
 rapid_layout/utils/pre_procss.py    | 14 +++++++
 tests/test_layout.py                |  7 +++-
 7 files changed, 134 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 0366909..303e21d 100644
--- a/README.md
+++ b/README.md
@@ -15,11 +15,12 @@
 </div>
 
 ### 简介
-主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图、研报等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
+
+该项目主要是汇集全网开源的版面分析的项目,具体来说,就是分析给定的文档类别图像(论文截图、研报等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
 
 ⚠️注意:需要说明的是,由于不同场景下的版面差异较大,现阶段不存在一个模型可以搞定所有场景。如果实际业务需要,以下模型效果不好的话,建议构建自己的训练集微调。
 
-目前支持以下场景的版面分析:
+目前支持已经支持的版面分析模型如下:
 
 |`model_type`| 版面类型 |        模型名称          |  支持类别|
 | :------ | :----- | :------ | :----- |
@@ -30,30 +31,38 @@
 | `yolov8n_layout_report`|   研报   |   `yolov8n_layout_report.onnx`    | `['Text', 'Title', 'Header', 'Footer', 'Figure', 'Table', 'Toc', 'Figure caption', 'Table caption']` |
 | `yolov8n_layout_publaynet`|   英文   |   `yolov8n_layout_publaynet.onnx`    | `["Text", "Title", "List", "Table", "Figure"]` |
 | `yolov8n_layout_general6`|   通用   |   `yolov8n_layout_general6.onnx`    | `["Text", "Title", "Figure", "Table", "Caption", "Equation"]` |
+| 🔥`doclayout_yolo`|   通用   |   `doclayout_yolo_docstructbench_imgsz1024.onnx`    | `['title', 'text', 'abandon', 'figure', 'figure_caption', 'table', 'table_caption', 'table_footnote', 'isolate_formula', 'formula_caption']` |
 
 PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
 
 yolov8n系列来源:[360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
 
+(推荐使用)🔥doclayout_yolo模型来源:[DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO),该模型是目前最为优秀的开源模型,支持学术论文、Textbook、Financial、Exam Paper、Fuzzy Scans、PPT和Poster 7种文档类型的版面检测。值得一提的是,该模型支持的类别中存在`abandon`一类,主要是文档页面的页眉页脚部分,便于后续快速舍弃。
+
 模型下载地址为:[link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
 
 ### 安装
+
 由于模型较小,预先将中文版面分析模型(`layout_cdla.onnx`)打包进了whl包内,如果做中文版面分析,可直接安装使用
 
 ```bash
-$ pip install rapid-layout
+pip install rapid-layout
 ```
 
 ### 使用方式
+
 #### python脚本运行
+
 ```python
 import cv2
+
 from rapid_layout import RapidLayout, VisLayout
 
 # model_type类型参见上表。指定不同model_type时,会自动下载相应模型到安装目录下的。
-layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.2)
 
-img = cv2.imread('test_images/layout.png')
+img_path = "tests/test_files/financial.jpg"
+img = cv2.imread(img_path)
 
 boxes, scores, class_names, elapse = layout_engine(img)
 ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
@@ -61,41 +70,51 @@ if ploted_img is not None:
     cv2.imwrite("layout_res.png", ploted_img)
 ```
 
+### 可视化结果
+
+<div align="center">
+    <img src="https://github.com/RapidAI/RapidLayout/releases/download/v0.0.0/layout_res.png" width="80%" height="80%">
+</div>
+
 #### 终端运行
+
 ```bash
 $ rapid_layout -h
 usage: rapid_layout [-h] -img IMG_PATH
-                [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                [--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                [--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}]
-                [--use_cuda] [--use_dml] [-v]
+                    [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}]
+                    [--use_cuda] [--use_dml] [-v]
 
 options:
   -h, --help            show this help message and exit
   -img IMG_PATH, --img_path IMG_PATH
                         Path to image for layout.
-  -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         Support model type
-  --conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  --conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         Box threshold, the range is [0, 1]
-  --iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6}
+  --iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report,yolov8n_layout_publaynet,yolov8n_layout_general6,doclayout_yolo}
                         IoU threshold, the range is [0, 1]
   --use_cuda            Whether to use cuda.
   --use_dml             Whether to use DirectML, which only works in Windows10+.
   -v, --vis             Wheter to visualize the layout results.
 ```
+
 - 示例:
+
     ```bash
-    $ rapid_layout -v -img test_images/layout.png
+    rapid_layout -v -img test_images/layout.png
     ```
 
-
 ### GPU推理
+
 - 因为版面分析模型输入图像尺寸固定,故可使用`onnxruntime-gpu`来提速。
 - 因为`rapid_layout`库默认依赖是CPU版`onnxruntime`,如果想要使用GPU推理,需要手动安装`onnxruntime-gpu`。
 - 详细使用和评测可参见[AI Studio](https://aistudio.baidu.com/projectdetail/8094594)
 
 #### 安装
+
 ```bash
 pip install rapid_layout
 pip uninstall onnxruntime
@@ -106,13 +125,14 @@ pip install onnxruntime-gpu
 ```
 
 #### 使用
+
 ```python
 import cv2
 from rapid_layout import RapidLayout
 from pathlib import Path
 
 # 注意:这里需要使用use_cuda指定参数
-layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla", use_cuda=True)
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.2, use_cuda=True)
 
 # warm up
 layout_engine("images/12027_5.png")
@@ -128,15 +148,10 @@ avg_elapse = sum(elapses) / len(elapses)
 print(f'avg elapse: {avg_elapse:.4f}')
 ```
 
-### 可视化结果
-
-<div align="center">
-    <img src="https://github.com/RapidAI/RapidLayout/releases/download/v0.0.0/layout_res.png" width="80%" height="80%">
-</div>
-
-
 ### 参考项目
+
+- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
 - [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
 - [360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
 - [ONNX-YOLOv8-Object-Detection](https://github.com/ibaiGorordo/ONNX-YOLOv8-Object-Detection)
-- [ChineseDocumentPDF](https://github.com/SWHL/ChineseDocumentPDF)
\ No newline at end of file
+- [ChineseDocumentPDF](https://github.com/SWHL/ChineseDocumentPDF)
diff --git a/demo.py b/demo.py
index 7e47d5a..bf398c8 100644
--- a/demo.py
+++ b/demo.py
@@ -5,9 +5,9 @@
 
 from rapid_layout import RapidLayout, VisLayout
 
-layout_engine = RapidLayout(model_type="yolov8n_layout_paper")
+layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.1)
 
-img_path = "tests/test_files/layout.png"
+img_path = "tests/test_files/PMC3576793_00004.jpg"
 img = cv2.imread(img_path)
 
 boxes, scores, class_names, elapse = layout_engine(img)
diff --git a/rapid_layout/main.py b/rapid_layout/main.py
index c17f06a..84125d4 100644
--- a/rapid_layout/main.py
+++ b/rapid_layout/main.py
@@ -10,6 +10,8 @@
 import numpy as np
 
 from .utils import (
+    DocLayoutPostProcess,
+    DocLayoutPreProcess,
     DownloadModel,
     LoadImage,
     OrtInferSession,
@@ -33,6 +35,7 @@
     "yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
     "yolov8n_layout_publaynet": f"{ROOT_URL}/yolov8n_layout_publaynet.onnx",
     "yolov8n_layout_general6": f"{ROOT_URL}/yolov8n_layout_general6.onnx",
+    "doclayout_yolo": f"{ROOT_URL}/doclayout_yolo_docstructbench_imgsz1024.onnx",
 }
 DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")
 
@@ -72,12 +75,20 @@ def __init__(
         self.yolov8_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
         self.yolov8_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
 
+        # doclayout
+        self.doclayout_input_shape = (1024, 1024)
+        self.doclayout_preprocess = DocLayoutPreProcess(
+            img_size=self.doclayout_input_shape
+        )
+        self.doclayout_postprocess = DocLayoutPostProcess(labels, conf_thres, iou_thres)
+
         self.load_img = LoadImage()
 
         self.pp_layout_type = [k for k in KEY_TO_MODEL_URL if k.startswith("pp")]
         self.yolov8_layout_type = [
             k for k in KEY_TO_MODEL_URL if k.startswith("yolov8n")
         ]
+        self.doclayout_type = [k for k in KEY_TO_MODEL_URL if k.startswith("doclayout")]
 
     def __call__(
         self, img_content: Union[str, np.ndarray, bytes, Path]
@@ -91,6 +102,9 @@ def __call__(
         if self.model_type in self.yolov8_layout_type:
             return self.yolov8_layout(img, ori_img_shape)
 
+        if self.model_type in self.doclayout_type:
+            return self.doclayout_layout(img, ori_img_shape)
+
         raise ValueError(f"{self.model_type} is not supported.")
 
     def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
@@ -114,6 +128,17 @@ def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
         elapse = time.time() - s_time
         return boxes, scores, class_names, elapse
 
+    def doclayout_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+        s_time = time.time()
+
+        input_tensor = self.doclayout_preprocess(img)
+        outputs = self.session(input_tensor)
+        boxes, scores, class_names = self.doclayout_postprocess(
+            outputs, ori_img_shape, self.doclayout_input_shape
+        )
+        elapse = time.time() - s_time
+        return boxes, scores, class_names, elapse
+
     @staticmethod
     def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
         if model_path is not None:
diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py
index ed27313..f9b7953 100644
--- a/rapid_layout/utils/__init__.py
+++ b/rapid_layout/utils/__init__.py
@@ -5,6 +5,6 @@
 from .infer_engine import OrtInferSession
 from .load_image import LoadImage, LoadImageError
 from .logger import get_logger
-from .post_prepross import PPPostProcess, YOLOv8PostProcess
-from .pre_procss import PPPreProcess, YOLOv8PreProcess
+from .post_prepross import DocLayoutPostProcess, PPPostProcess, YOLOv8PostProcess
+from .pre_procss import DocLayoutPreProcess, PPPreProcess, YOLOv8PreProcess
 from .vis_res import VisLayout
diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py
index 516fc33..bba6916 100644
--- a/rapid_layout/utils/post_prepross.py
+++ b/rapid_layout/utils/post_prepross.py
@@ -288,23 +288,61 @@ def extract_boxes(self, predictions):
         boxes = predictions[:, :4]
 
         # Scale boxes to original image dimensions
-        boxes = self.rescale_boxes(boxes)
+        boxes = rescale_boxes(
+            boxes, self.input_width, self.input_height, self.img_width, self.img_height
+        )
 
         # Convert boxes to xyxy format
         boxes = xywh2xyxy(boxes)
 
         return boxes
 
-    def rescale_boxes(self, boxes):
+
+class DocLayoutPostProcess:
+    def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
+        self.labels = labels
+        self.conf_threshold = conf_thres
+        self.iou_threshold = iou_thres
+        self.input_width, self.input_height = None, None
+        self.img_width, self.img_height = None, None
+
+    def __call__(
+        self,
+        output,
+        ori_img_shape: Tuple[int, int],
+        img_shape: Tuple[int, int] = (1024, 1024),
+    ):
+        self.img_height, self.img_width = ori_img_shape
+        self.input_height, self.input_width = img_shape
+
+        output = output[0].squeeze()
+        boxes = output[:, :-2]
+        confidences = output[:, -2]
+        class_ids = output[:, -1].astype(int)
+
+        mask = confidences > self.conf_threshold
+        boxes = boxes[mask, :]
+        confidences = confidences[mask]
+        class_ids = class_ids[mask]
+
         # Rescale boxes to original image dimensions
-        input_shape = np.array(
-            [self.input_width, self.input_height, self.input_width, self.input_height]
-        )
-        boxes = np.divide(boxes, input_shape, dtype=np.float32)
-        boxes *= np.array(
-            [self.img_width, self.img_height, self.img_width, self.img_height]
+        boxes = rescale_boxes(
+            boxes,
+            self.input_width,
+            self.input_height,
+            self.img_width,
+            self.img_height,
         )
-        return boxes
+        labels = [self.labels[i] for i in class_ids]
+        return boxes, confidences, labels
+
+
+def rescale_boxes(boxes, input_width, input_height, img_width, img_height):
+    # Rescale boxes to original image dimensions
+    input_shape = np.array([input_width, input_height, input_width, input_height])
+    boxes = np.divide(boxes, input_shape, dtype=np.float32)
+    boxes *= np.array([img_width, img_height, img_width, img_height])
+    return boxes
 
 
 def nms(boxes, scores, iou_threshold):
diff --git a/rapid_layout/utils/pre_procss.py b/rapid_layout/utils/pre_procss.py
index 78ce748..0e8727a 100644
--- a/rapid_layout/utils/pre_procss.py
+++ b/rapid_layout/utils/pre_procss.py
@@ -51,3 +51,17 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
         input_img = input_img.transpose(2, 0, 1)
         input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
         return input_tensor
+
+
+class DocLayoutPreProcess:
+
+    def __init__(self, img_size: Tuple[int, int]):
+        self.img_size = img_size
+
+    def __call__(self, image: np.ndarray) -> np.ndarray:
+        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        input_img = cv2.resize(image, self.img_size)
+        input_img = input_img / 255.0
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+        return input_tensor
diff --git a/tests/test_layout.py b/tests/test_layout.py
index f7f507e..b40713e 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -22,7 +22,12 @@
 
 
 @pytest.mark.parametrize(
-    "model_type,gt", [("yolov8n_layout_publaynet", 12), ("yolov8n_layout_general6", 13)]
+    "model_type,gt",
+    [
+        ("yolov8n_layout_publaynet", 12),
+        ("yolov8n_layout_general6", 13),
+        ("doclayout_yolo", 14),
+    ],
 )
 def test_yolov8n_layout(model_type, gt):
     img_path = test_file_dir / "PMC3576793_00004.jpg"