diff --git a/README.md b/README.md index 4eaf4a7..deb2be2 100644 --- a/README.md +++ b/README.md @@ -102,12 +102,12 @@ tabled_gui ```python from tabled.extract import extract_tables from tabled.fileinput import load_pdfs_images -from tabled.inference.models import load_detection_models, load_recognition_models +from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models -det_models, rec_models = load_detection_models(), load_recognition_models() +det_models, rec_models, layout_models = load_detection_models(), load_recognition_models(), load_layout_models() images, highres_images, names, text_lines = load_pdfs_images(IN_PATH) -page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models) +page_results = extract_tables(images, highres_images, text_lines, det_models, layout_models, rec_models) ``` # Benchmarks diff --git a/extract.py b/extract.py index 720ce03..db13d27 100644 --- a/extract.py +++ b/extract.py @@ -10,7 +10,7 @@ from tabled.extract import extract_tables from tabled.formats import formatter from tabled.fileinput import load_pdfs_images -from tabled.inference.models import load_detection_models, load_recognition_models +from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models @click.command(help="Extract tables from PDFs") @@ -36,8 +36,9 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete det_models = load_detection_models() rec_models = load_recognition_models() + layout_models = load_layout_models() - page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_cell_boxes) + page_results = extract_tables(images, highres_images, text_lines, det_models, layout_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_cell_boxes) out_json = defaultdict(list) for name, pnum, result in zip(names, pnums, page_results): diff --git a/pyproject.toml b/pyproject.toml index aa71671..78e28c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tabled-pdf" -version = "0.1.6" +version = "0.1.7" description = "Detect and recognize tables in PDFs and images." authors = ["Vik Paruchuri "] readme = "README.md" diff --git a/table_app.py b/table_app.py index 7dc69f7..82aadf3 100644 --- a/table_app.py +++ b/table_app.py @@ -15,24 +15,24 @@ from PIL import Image import streamlit as st -from tabled.inference.models import load_detection_models, load_recognition_models +from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models @st.cache_resource() def load_models(): - return load_detection_models(), load_recognition_models() + return load_detection_models(), load_recognition_models(), load_layout_models() def run_table_rec(image, highres_image, text_line, models, skip_detection=False, detect_boxes=False): if not skip_detection: - table_imgs, table_bboxes, _ = detect_tables([image], [highres_image], models[0]) + table_imgs, table_bboxes, _ = detect_tables([image], [highres_image], models[2]) else: table_imgs = [highres_image] table_bboxes = [[0, 0, highres_image.size[0], highres_image.size[1]]] table_text_lines = [text_line] * len(table_imgs) highres_image_sizes = [highres_image.size] * len(table_imgs) - cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, models[0][:2], detect_boxes=detect_boxes) + cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, models[0], detect_boxes=detect_boxes) table_rec = recognize_tables(table_imgs, cells, needs_ocr, models[1]) cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)] diff --git a/tabled/extract.py b/tabled/extract.py index 7284f3b..edd7f4b 100644 --- a/tabled/extract.py +++ b/tabled/extract.py @@ -13,12 +13,13 @@ def extract_tables( highres_images, text_lines, det_models, + layout_models, rec_models, skip_detection=False, detect_boxes=False ) -> List[ExtractPageResult]: if not skip_detection: - table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, det_models) + table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, layout_models) else: table_imgs = highres_images table_bboxes = [[0, 0, img.size[0], img.size[1]] for img in highres_images] @@ -30,7 +31,7 @@ def extract_tables( table_text_lines.extend([text_lines[i]] * tc) highres_image_sizes.extend([highres_images[i].size] * tc) - cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, det_models[:2], detect_boxes=detect_boxes) + cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, det_models, detect_boxes=detect_boxes) table_rec = recognize_tables(table_imgs, cells, needs_ocr, rec_models) cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)] diff --git a/tabled/inference/detection.py b/tabled/inference/detection.py index 67d770f..e33f707 100644 --- a/tabled/inference/detection.py +++ b/tabled/inference/detection.py @@ -1,4 +1,3 @@ -from surya.detection import batch_text_detection from surya.layout import batch_layout_detection from surya.postprocessing.util import rescale_bbox from surya.schema import Bbox diff --git a/tabled/inference/models.py b/tabled/inference/models.py index 459b4d7..6144260 100644 --- a/tabled/inference/models.py +++ b/tabled/inference/models.py @@ -8,9 +8,9 @@ def load_detection_models(): - layout_model = load_layout_model() - layout_processor = load_layout_processor() - return layout_model, layout_processor + detection_model = load_det_model() + detection_processor = load_det_processor() + return detection_model, detection_processor def load_recognition_models(): @@ -18,4 +18,9 @@ def load_recognition_models(): table_rec_processor = load_table_rec_processor() rec_model = load_rec_model() rec_processor = load_rec_processor() - return table_rec_model, table_rec_processor, rec_model, rec_processor \ No newline at end of file + return table_rec_model, table_rec_processor, rec_model, rec_processor + +def load_layout_models(): + layout_model = load_layout_model() + layout_processor = load_layout_processor() + return layout_model, layout_processor \ No newline at end of file